Jitendra14355 commited on
Commit
dc3e00d
·
verified ·
1 Parent(s): 6e1e4ff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================================
2
+ # Dialogue System using DialoGPT (Gradio)
3
+ # =========================================
4
+
5
+ import gradio as gr
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+ # -----------------------------
10
+ # 1. Load Model & Tokenizer
11
+ # -----------------------------
12
+ MODEL_NAME = "microsoft/DialoGPT-medium"
13
+
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ model = model.to(device)
19
+
20
+ # -----------------------------
21
+ # 2. Chat Function
22
+ # -----------------------------
23
+ chat_history_ids = None
24
+
25
+ def chat(user_input, history):
26
+ global chat_history_ids
27
+
28
+ if not user_input.strip():
29
+ return history, ""
30
+
31
+ # Encode user input
32
+ new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to(device)
33
+
34
+ # Append to history
35
+ if chat_history_ids is not None:
36
+ bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
37
+ else:
38
+ bot_input_ids = new_input_ids
39
+
40
+ # Generate response
41
+ chat_history_ids = model.generate(
42
+ bot_input_ids,
43
+ max_length=1000,
44
+ pad_token_id=tokenizer.eos_token_id,
45
+ do_sample=True,
46
+ top_k=50,
47
+ top_p=0.95,
48
+ temperature=0.7
49
+ )
50
+
51
+ # Decode response
52
+ response = tokenizer.decode(
53
+ chat_history_ids[:, bot_input_ids.shape[-1]:][0],
54
+ skip_special_tokens=True
55
+ )
56
+
57
+ history.append((user_input, response))
58
+ return history, ""
59
+
60
+ # -----------------------------
61
+ # 3. Reset Function
62
+ # -----------------------------
63
+ def reset_chat():
64
+ global chat_history_ids
65
+ chat_history_ids = None
66
+ return [], ""
67
+
68
+ # -----------------------------
69
+ # 4. Gradio UI
70
+ # -----------------------------
71
+ with gr.Blocks(title="Dialogue System") as app:
72
+
73
+ gr.Markdown("## 🤖 AI Dialogue System (Chatbot)")
74
+ gr.Markdown("Chat with an AI using DialoGPT")
75
+
76
+ chatbot = gr.Chatbot()
77
+
78
+ with gr.Row():
79
+ user_input = gr.Textbox(
80
+ placeholder="Type your message...",
81
+ show_label=False
82
+ )
83
+
84
+ with gr.Row():
85
+ send_btn = gr.Button("Send")
86
+ clear_btn = gr.Button("Clear Chat")
87
+
88
+ # Button actions
89
+ send_btn.click(
90
+ chat,
91
+ inputs=[user_input, chatbot],
92
+ outputs=[chatbot, user_input]
93
+ )
94
+
95
+ user_input.submit(
96
+ chat,
97
+ inputs=[user_input, chatbot],
98
+ outputs=[chatbot, user_input]
99
+ )
100
+
101
+ clear_btn.click(
102
+ reset_chat,
103
+ outputs=[chatbot, user_input]
104
+ )
105
+
106
+ # -----------------------------
107
+ # 5. Launch
108
+ # -----------------------------
109
+ app.launch()