jowilke77 committed on
Commit
abddbd7
·
verified ·
1 Parent(s): 3c95cd1

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +13 -16
inference.py CHANGED
@@ -1,32 +1,29 @@
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
 
4
- MODEL_NAME = "distilgpt2"
 
5
 
6
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
7
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
 
 
8
 
9
  with open("prompt.txt", "r") as f:
10
  SYSTEM_PROMPT = f.read().strip()
11
 
12
- def generate(user_input):
13
  prompt = f"{SYSTEM_PROMPT}\n\nUser: {user_input}\nBrad AI:"
14
  inputs = tokenizer(prompt, return_tensors="pt")
15
 
16
  with torch.no_grad():
17
- outputs = model.generate(
18
  **inputs,
19
- max_new_tokens=150,
20
- temperature=0.7,
21
- top_p=0.9,
22
  do_sample=True
23
  )
24
 
25
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
26
-
27
- if __name__ == "__main__":
28
- while True:
29
- user = input("You: ")
30
- if user.lower() in ["exit", "quit"]:
31
- break
32
- print(generate(user))
 
1
+ import json
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
+ with open("config.json", "r") as f:
6
+ cfg = json.load(f)
7
 
8
+ BASE_MODEL = cfg["base_model"]
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
11
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
12
 
13
  with open("prompt.txt", "r") as f:
14
  SYSTEM_PROMPT = f.read().strip()
15
 
16
+ def chat(user_input):
17
  prompt = f"{SYSTEM_PROMPT}\n\nUser: {user_input}\nBrad AI:"
18
  inputs = tokenizer(prompt, return_tensors="pt")
19
 
20
  with torch.no_grad():
21
+ output = model.generate(
22
  **inputs,
23
+ max_new_tokens=cfg["max_new_tokens"],
24
+ temperature=cfg["temperature"],
25
+ top_p=cfg["top_p"],
26
  do_sample=True
27
  )
28
 
29
+ return tokenizer.decode(output[0], skip_special_tokens=True)