prelington committed on
Commit
ab9ffec
·
verified ·
1 Parent(s): 2eaa513

Create chat.py

Browse files
Files changed (1) hide show
  1. chat.py +21 -0
chat.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # chat.py
2
+ import torch
3
+ from model_loader import load_model
4
+ from config import MAX_TOKENS, TEMPERATURE
5
# Load the tokenizer and model once at module import time so that every
# call to generate_response() reuses the same instances instead of
# reloading weights per call.
# NOTE(review): load_model() is defined in model_loader (not visible here);
# presumably it returns a (tokenizer, model) pair — confirm against that module.
tokenizer, model = load_model()
8
def generate_response(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
    """Generate a sampled text completion for *prompt* with the loaded model.

    Parameters
    ----------
    prompt : str
        Input text; tokenized and moved to the model's device.
    max_length : int, optional
        Upper bound handed to ``model.generate`` (default ``config.MAX_TOKENS``).
        NOTE(review): ``max_length`` limits prompt + generated tokens combined;
        confirm ``max_new_tokens`` was not intended instead.
    temperature : float, optional
        Sampling temperature for ``do_sample=True`` generation
        (default ``config.TEMPERATURE``).

    Returns
    -------
    str
        The decoded output sequence with special tokens stripped.
        Note it includes the prompt text, since decoding starts from
        the full generated sequence.
    """
    # Tokenize and place tensors on the same device as the model.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Sampled decoding; pad_token_id is set explicitly because many causal
    # LM tokenizers define no pad token, which makes generate() warn/fail.
    generated = model.generate(
        **encoded,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns a batch; decode the single sequence we produced.
    return tokenizer.decode(generated[0], skip_special_tokens=True)