Ram7379 commited on
Commit
653d26b
·
verified ·
1 Parent(s): 96a9ff4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -2,22 +2,25 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
 
5
  model_name = "microsoft/DialoGPT-medium"
6
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
 
 
10
  def reply(message, history):
11
  if not message.strip():
12
  return "Please enter a message."
13
 
14
  chat_history_ids = None
15
 
16
- # FIX: handle history safely
17
  for msg in history:
18
  content = msg["content"]
19
 
20
- # FIX: sometimes content is list
21
  if isinstance(content, list):
22
  content = " ".join([str(x) for x in content])
23
 
@@ -31,7 +34,7 @@ def reply(message, history):
31
  else:
32
  chat_history_ids = torch.cat([chat_history_ids, ids], dim=-1)
33
 
34
- # current input
35
  new_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
36
 
37
  if chat_history_ids is not None:
@@ -39,8 +42,10 @@ def reply(message, history):
39
  else:
40
  input_ids = new_input_ids
41
 
 
42
  attention_mask = torch.ones_like(input_ids)
43
 
 
44
  output_ids = model.generate(
45
  input_ids,
46
  attention_mask=attention_mask,
@@ -53,24 +58,27 @@ def reply(message, history):
53
  repetition_penalty=1.2
54
  )
55
 
56
- # extract response
57
  response_ids = output_ids[:, input_ids.shape[-1]:]
58
  response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
59
 
 
60
  if response.strip() == "":
61
  response = "I'm here! How can I help you?"
62
 
63
  return response
64
 
 
 
65
  demo = gr.ChatInterface(
66
  fn=reply,
67
  title="💬 Smart Dialogue System",
68
- description="Full conversation chatbot (fixed)"
69
  )
70
 
 
71
  demo.launch(
72
  server_name="0.0.0.0",
73
  server_port=7860,
74
  ssr_mode=False
75
- )
76
  )
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
+ # Load model
6
  model_name = "microsoft/DialoGPT-medium"
7
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
+
12
+ # Chat function
13
  def reply(message, history):
14
  if not message.strip():
15
  return "Please enter a message."
16
 
17
  chat_history_ids = None
18
 
19
+ # Handle previous conversation
20
  for msg in history:
21
  content = msg["content"]
22
 
23
+ # Fix: if content is list → convert to string
24
  if isinstance(content, list):
25
  content = " ".join([str(x) for x in content])
26
 
 
34
  else:
35
  chat_history_ids = torch.cat([chat_history_ids, ids], dim=-1)
36
 
37
+ # Current message
38
  new_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
39
 
40
  if chat_history_ids is not None:
 
42
  else:
43
  input_ids = new_input_ids
44
 
45
+ # Attention mask fix
46
  attention_mask = torch.ones_like(input_ids)
47
 
48
+ # Generate response
49
  output_ids = model.generate(
50
  input_ids,
51
  attention_mask=attention_mask,
 
58
  repetition_penalty=1.2
59
  )
60
 
61
+ # Extract only new response
62
  response_ids = output_ids[:, input_ids.shape[-1]:]
63
  response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
64
 
65
+ # Fallback
66
  if response.strip() == "":
67
  response = "I'm here! How can I help you?"
68
 
69
  return response
70
 
71
+
72
+ # UI
73
  demo = gr.ChatInterface(
74
  fn=reply,
75
  title="💬 Smart Dialogue System",
76
+ description="Full conversation chatbot using DialoGPT"
77
  )
78
 
79
+ # Launch
80
  demo.launch(
81
  server_name="0.0.0.0",
82
  server_port=7860,
83
  ssr_mode=False
 
84
  )