Ram7379 commited on
Commit
21f0fe1
·
verified ·
1 Parent(s): c4e54cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -19
app.py CHANGED
@@ -1,40 +1,65 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
 
4
- # Better chatbot model
5
- chatbot = pipeline(
6
- "text2text-generation",
7
- model="google/flan-t5-small"
8
- )
9
 
 
10
  def reply(message, history):
11
  if not message.strip():
12
  return "Please enter a message."
13
 
14
- prompt = f"User: {message}\nAssistant:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- result = chatbot(
17
- prompt,
18
- max_new_tokens=50,
 
 
19
  do_sample=True,
20
- temperature=0.7
 
 
 
21
  )
22
 
23
- output = result[0]["generated_text"]
 
 
24
 
25
- # Extract only assistant response
26
- if "Assistant:" in output:
27
- response = output.split("Assistant:")[-1].strip()
28
- else:
29
- response = output.strip()
30
 
31
- return response # ✅ ONLY RETURN STRING
32
 
33
  # Chat UI
34
  demo = gr.ChatInterface(
35
  fn=reply,
36
  title="💬 Smart Dialogue System",
37
- description="Chatbot using FLAN-T5"
38
  )
39
 
40
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
 
5
+ # Load chat model
6
+ model_name = "microsoft/DialoGPT-medium"
7
+
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
+ # Chat function with memory
12
  def reply(message, history):
13
  if not message.strip():
14
  return "Please enter a message."
15
 
16
+ # Build conversation history
17
+ chat_history_ids = None
18
+
19
+ for user, bot in history:
20
+ user_ids = tokenizer.encode(user + tokenizer.eos_token, return_tensors="pt")
21
+ bot_ids = tokenizer.encode(bot + tokenizer.eos_token, return_tensors="pt")
22
+ chat_history_ids = (
23
+ user_ids if chat_history_ids is None
24
+ else torch.cat([chat_history_ids, user_ids], dim=-1)
25
+ )
26
+ chat_history_ids = torch.cat([chat_history_ids, bot_ids], dim=-1)
27
+
28
+ # Add current user message
29
+ new_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
30
+
31
+ input_ids = (
32
+ new_input_ids if chat_history_ids is None
33
+ else torch.cat([chat_history_ids, new_input_ids], dim=-1)
34
+ )
35
 
36
+ # Generate response
37
+ output_ids = model.generate(
38
+ input_ids,
39
+ max_length=1000,
40
+ pad_token_id=tokenizer.eos_token_id,
41
  do_sample=True,
42
+ top_k=50,
43
+ top_p=0.95,
44
+ temperature=0.7,
45
+ repetition_penalty=1.2 # 🔥 stops repetition
46
  )
47
 
48
+ # Extract only new response
49
+ response_ids = output_ids[:, input_ids.shape[-1]:]
50
+ response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
51
 
52
+ # Clean fallback
53
+ if response.strip() == "":
54
+ response = "I'm here! How can I help you?"
 
 
55
 
56
+ return response
57
 
58
  # Chat UI
59
  demo = gr.ChatInterface(
60
  fn=reply,
61
  title="💬 Smart Dialogue System",
62
+ description="Chatbot using DialoGPT (context-aware)"
63
  )
64
 
65
  demo.launch(server_name="0.0.0.0", server_port=7860)