Jitendra14355 commited on
Commit
3b96f8b
·
verified ·
1 Parent(s): e96ac87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -45
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # =========================================
2
- # Dialogue System using DialoGPT (Gradio)
3
  # =========================================
4
 
5
  import gradio as gr
@@ -7,9 +7,9 @@ import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
  # -----------------------------
10
- # 1. Load Model & Tokenizer
11
  # -----------------------------
12
- MODEL_NAME = "microsoft/DialoGPT-medium"
13
 
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
@@ -17,21 +17,21 @@ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model = model.to(device)
19
 
20
- # -----------------------------
21
- # 2. Chat Function
22
- # -----------------------------
23
  chat_history_ids = None
24
 
25
- def chat(user_input, history):
 
 
 
26
  global chat_history_ids
27
 
28
- if not user_input.strip():
29
- return history, ""
30
 
31
- # Encode user input
32
- new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to(device)
33
 
34
- # Append to history
35
  if chat_history_ids is not None:
36
  bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
37
  else:
@@ -40,7 +40,7 @@ def chat(user_input, history):
40
  # Generate response
41
  chat_history_ids = model.generate(
42
  bot_input_ids,
43
- max_length=1000,
44
  pad_token_id=tokenizer.eos_token_id,
45
  do_sample=True,
46
  top_k=50,
@@ -48,13 +48,16 @@ def chat(user_input, history):
48
  temperature=0.7
49
  )
50
 
51
- # Decode response
52
  response = tokenizer.decode(
53
  chat_history_ids[:, bot_input_ids.shape[-1]:][0],
54
  skip_special_tokens=True
55
  )
56
 
57
- history.append((user_input, response))
 
 
 
58
  return history, ""
59
 
60
  # -----------------------------
@@ -66,44 +69,25 @@ def reset_chat():
66
  return [], ""
67
 
68
  # -----------------------------
69
- # 4. Gradio UI
70
  # -----------------------------
71
- with gr.Blocks(title="Dialogue System") as app:
 
72
 
73
- gr.Markdown("## 🤖 AI Dialogue System (Chatbot)")
74
- gr.Markdown("Chat with an AI using DialoGPT")
75
 
76
- chatbot = gr.Chatbot()
77
 
78
  with gr.Row():
79
- user_input = gr.Textbox(
80
- placeholder="Type your message...",
81
- show_label=False
82
- )
83
 
84
- with gr.Row():
85
- send_btn = gr.Button("Send")
86
- clear_btn = gr.Button("Clear Chat")
87
-
88
- # Button actions
89
- send_btn.click(
90
- chat,
91
- inputs=[user_input, chatbot],
92
- outputs=[chatbot, user_input]
93
- )
94
 
95
- user_input.submit(
96
- chat,
97
- inputs=[user_input, chatbot],
98
- outputs=[chatbot, user_input]
99
- )
100
-
101
- clear_btn.click(
102
- reset_chat,
103
- outputs=[chatbot, user_input]
104
- )
105
 
106
  # -----------------------------
107
  # 5. Launch
108
  # -----------------------------
109
- app.launch()
 
1
  # =========================================
2
+ # Dialogue System (Gradio FIXED VERSION)
3
  # =========================================
4
 
5
  import gradio as gr
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
  # -----------------------------
10
+ # 1. Load Model
11
  # -----------------------------
12
+ MODEL_NAME = "microsoft/DialoGPT-small" # lighter & faster
13
 
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model = model.to(device)
19
 
 
 
 
20
  chat_history_ids = None
21
 
22
+ # -----------------------------
23
+ # 2. Chat Function (FIXED)
24
+ # -----------------------------
25
+ def chat(message, history):
26
  global chat_history_ids
27
 
28
+ if history is None:
29
+ history = []
30
 
31
+ # Encode input
32
+ new_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt").to(device)
33
 
34
+ # Append history
35
  if chat_history_ids is not None:
36
  bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
37
  else:
 
40
  # Generate response
41
  chat_history_ids = model.generate(
42
  bot_input_ids,
43
+ max_length=500,
44
  pad_token_id=tokenizer.eos_token_id,
45
  do_sample=True,
46
  top_k=50,
 
48
  temperature=0.7
49
  )
50
 
51
+ # Decode
52
  response = tokenizer.decode(
53
  chat_history_ids[:, bot_input_ids.shape[-1]:][0],
54
  skip_special_tokens=True
55
  )
56
 
57
+ # ✅ NEW FORMAT (IMPORTANT FIX)
58
+ history.append({"role": "user", "content": message})
59
+ history.append({"role": "assistant", "content": response})
60
+
61
  return history, ""
62
 
63
  # -----------------------------
 
69
  return [], ""
70
 
71
  # -----------------------------
72
+ # 4. UI (NEW CHAT INTERFACE)
73
  # -----------------------------
74
+ with gr.Blocks() as demo:
75
+ gr.Markdown("## 🤖 AI Dialogue System")
76
 
77
+ chatbot = gr.Chatbot(type="messages") # IMPORTANT
 
78
 
79
+ msg = gr.Textbox(placeholder="Type your message...")
80
 
81
  with gr.Row():
82
+ send = gr.Button("Send")
83
+ clear = gr.Button("Clear Chat")
 
 
84
 
85
+ send.click(chat, [msg, chatbot], [chatbot, msg])
86
+ msg.submit(chat, [msg, chatbot], [chatbot, msg])
 
 
 
 
 
 
 
 
87
 
88
+ clear.click(reset_chat, outputs=[chatbot, msg])
 
 
 
 
 
 
 
 
 
89
 
90
  # -----------------------------
91
  # 5. Launch
92
  # -----------------------------
93
+ demo.launch()