jowilke77 committed on
Commit
a0f9301
·
verified ·
1 Parent(s): bffe28d

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +26 -10
inference.py CHANGED
@@ -1,25 +1,34 @@
1
  import json
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
4
 
5
- with open("config.json", "r") as f:
6
  cfg = json.load(f)
7
 
8
- BASE_MODEL = cfg["base_model"]
9
-
10
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
11
- model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
 
 
12
 
13
- with open("prompt.txt", "r") as f:
14
  SYSTEM_PROMPT = f.read().strip()
15
 
16
  def chat(user_input):
17
- prompt = f"{SYSTEM_PROMPT}\n\nUser: {user_input}\nBrad AI:"
18
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
19
 
20
  with torch.no_grad():
21
  output = model.generate(
22
- **inputs,
23
  max_new_tokens=cfg["max_new_tokens"],
24
  temperature=cfg["temperature"],
25
  top_p=cfg["top_p"],
@@ -27,3 +36,10 @@ def chat(user_input):
27
  )
28
 
29
  return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
1
  import json
 
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
+ with open("config.json") as f:
6
  cfg = json.load(f)
7
 
8
+ tokenizer = AutoTokenizer.from_pretrained(cfg["base_model"])
9
+ model = AutoModelForCausalLM.from_pretrained(
10
+ cfg["base_model"],
11
+ torch_dtype=torch.float32,
12
+ device_map="cpu"
13
+ )
14
 
15
+ with open("prompt.txt") as f:
16
  SYSTEM_PROMPT = f.read().strip()
17
 
18
  def chat(user_input):
19
+ messages = [
20
+ {"role": "system", "content": SYSTEM_PROMPT},
21
+ {"role": "user", "content": user_input}
22
+ ]
23
+
24
+ input_ids = tokenizer.apply_chat_template(
25
+ messages,
26
+ return_tensors="pt"
27
+ )
28
 
29
  with torch.no_grad():
30
  output = model.generate(
31
+ input_ids,
32
  max_new_tokens=cfg["max_new_tokens"],
33
  temperature=cfg["temperature"],
34
  top_p=cfg["top_p"],
 
36
  )
37
 
38
  return tokenizer.decode(output[0], skip_special_tokens=True)
39
+
40
+ if __name__ == "__main__":
41
+ while True:
42
+ msg = input("You: ")
43
+ if msg.lower() in ("exit", "quit"):
44
+ break
45
+ print(chat(msg))