Prajjwalng commited on
Commit
20bc52b
·
verified ·
1 Parent(s): efc351e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -15
app.py CHANGED
@@ -24,26 +24,35 @@ def load_model():
24
  return tokenizer, model
25
 
26
  tokenizer, model = load_model()
 
 
 
27
 
28
- # Function to generate chatbot response
29
- def generate_response(prompt, chat_history=""):
30
- inputs = tokenizer.encode(chat_history + prompt, return_tensors="pt")
31
 
32
- # Generate a response
33
- outputs = model.generate(
34
- inputs,
35
- max_length=1000,
36
- pad_token_id=tokenizer.eos_token_id,
37
- temperature=0.7,
38
- top_k=50,
39
- top_p=0.95,
40
- )
41
 
42
- response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
43
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # Streamlit app
46
- st.title("Gemma-2b-it Chatbot")
47
 
48
  # Initialize chat history
49
  if "messages" not in st.session_state:
 
24
  return tokenizer, model
25
 
26
  tokenizer, model = load_model()
27
+ # Function to generate chatbot response using the provided template
28
+ def get_completion(query: str, model, tokenizer) -> str:
29
+ device = "cuda:0" if torch.cuda.is_available() else "cpu" #Use cuda if available.
30
 
31
+ prompt_template = f"""
32
+ <start_of_turn>system You are a support chatbot who helps with user queries chatbot who always responds in the style of a professional.\n<end_of_turn>
33
+ <start_of_turn>user
34
 
 
 
 
 
 
 
 
 
 
35
 
36
+ {query}
37
+ <end_of_turn>
38
+
39
+ <start_of_turn>model
40
+ """
41
+ prompt = prompt_template.format(query=query)
42
+
43
+ encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
44
+
45
+ model_inputs = encodeds.to(device)
46
+
47
+ model.to(device) #Move model to device.
48
+
49
+ generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
50
+ decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
51
+ model_response = decoded.split("model\n")[-1].strip()
52
+ return model_response
53
 
54
  # Streamlit app
55
+ st.title("Customer Care Chatbot")
56
 
57
  # Initialize chat history
58
  if "messages" not in st.session_state: