eabybabu commited on
Commit
5cfc235
·
1 Parent(s): 4729206

Optimized chatbot for speed with CUDA & quantization

Browse files
Files changed (1) hide show
  1. app.py +19 -18
app.py CHANGED
@@ -1,43 +1,46 @@
1
  import os
 
2
  import gradio as gr
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  # ✅ Load API Token Securely from Hugging Face Secrets
6
  HF_TOKEN = os.getenv("HF_TOKEN")
7
 
8
- # ✅ Load model and tokenizer from Hugging Face Model Hub
9
  MODEL_NAME = "eabybabu/chatbot_model" # Replace with your actual model name
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
11
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
12
 
13
- # ✅ Function to generate chatbot responses while maintaining chat history
 
 
 
 
 
 
 
14
  def chatbot_response(user_input, chat_history):
15
  try:
16
- # Combine chat history with new query
17
  chat_context = " ".join([f"User: {msg}\nChatbot: {resp}" for msg, resp in chat_history])
18
  prompt = f"{chat_context}\nUser: {user_input}\nChatbot:"
19
 
20
  # Encode input
21
- inputs = tokenizer.encode(prompt, return_tensors="pt")
22
 
23
- # Generate response
24
  outputs = model.generate(
25
  inputs,
26
- max_length=300, # Control response length
27
- temperature=0.7, # Controls randomness
28
- top_k=50, # Limits token selection
29
- top_p=0.9, # Nucleus sampling
30
- repetition_penalty=1.5, # Prevents repetition
31
- num_return_sequences=1 # Return one response
32
  )
33
 
34
  # Decode response
35
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
36
 
37
- # Clean up response (remove repeated parts)
38
- response = ". ".join(set(response.split(". ")))
39
-
40
- # Append new message to history
41
  chat_history.append((user_input, response))
42
 
43
  return chat_history, ""
@@ -54,10 +57,8 @@ with gr.Blocks() as demo:
54
  user_input = gr.Textbox(label="Type your question:")
55
  submit_btn = gr.Button("Ask Chatbot")
56
 
57
- # Initialize chat history
58
  chat_history = gr.State([])
59
 
60
- # Connect button to chatbot function
61
  submit_btn.click(chatbot_response, inputs=[user_input, chat_history], outputs=[chatbot, user_input])
62
 
63
  # ✅ Launch the Gradio app
 
1
  import os
2
+ import torch
3
  import gradio as gr
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  # ✅ Load API Token Securely from Hugging Face Secrets
7
  HF_TOKEN = os.getenv("HF_TOKEN")
8
 
9
+ # ✅ Load model and tokenizer (Optimized for Speed)
10
  MODEL_NAME = "eabybabu/chatbot_model" # Replace with your actual model name
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
 
12
 
13
+ # ✅ Use GPU if available
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+ # ✅ Load model and apply quantization (if available)
17
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN).to(device)
18
+ model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) # Apply quantization
19
+
20
+ # ✅ Function to generate chatbot responses with chat history
21
  def chatbot_response(user_input, chat_history):
22
  try:
 
23
  chat_context = " ".join([f"User: {msg}\nChatbot: {resp}" for msg, resp in chat_history])
24
  prompt = f"{chat_context}\nUser: {user_input}\nChatbot:"
25
 
26
  # Encode input
27
+ inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
28
 
29
+ # Generate response (Faster with CUDA & Optimized Settings)
30
  outputs = model.generate(
31
  inputs,
32
+ max_length=200,
33
+ temperature=0.7,
34
+ top_k=50,
35
+ top_p=0.9,
36
+ repetition_penalty=1.5,
37
+ num_return_sequences=1
38
  )
39
 
40
  # Decode response
41
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
42
+ response = ". ".join(set(response.split(". "))) # Prevent repetition
43
 
 
 
 
 
44
  chat_history.append((user_input, response))
45
 
46
  return chat_history, ""
 
57
  user_input = gr.Textbox(label="Type your question:")
58
  submit_btn = gr.Button("Ask Chatbot")
59
 
 
60
  chat_history = gr.State([])
61
 
 
62
  submit_btn.click(chatbot_response, inputs=[user_input, chat_history], outputs=[chatbot, user_input])
63
 
64
  # ✅ Launch the Gradio app