hamza2923 commited on
Commit
05f1d97
·
verified ·
1 Parent(s): 3ae2785

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ # Load model and tokenizer
6
+ model_name = "microsoft/DialoGPT-small"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForCausalLM.from_pretrained(model_name)
9
+
10
+ def respond(message, chat_history, chat_history_ids):
11
+ # Encode user input
12
+ new_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
13
+
14
+ # Append to chat history
15
+ if chat_history_ids is not None:
16
+ input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
17
+ else:
18
+ input_ids = new_input_ids
19
+
20
+ # Generate response
21
+ chat_history_ids = model.generate(
22
+ input_ids,
23
+ max_length=1000,
24
+ pad_token_id=tokenizer.eos_token_id,
25
+ no_repeat_ngram_size=3,
26
+ do_sample=True,
27
+ top_k=50,
28
+ top_p=0.95,
29
+ temperature=0.8
30
+ )
31
+
32
+ # Decode response
33
+ response = tokenizer.decode(
34
+ chat_history_ids[:, input_ids.shape[-1]:][0],
35
+ skip_special_tokens=True
36
+ )
37
+
38
+ # Update conversation history
39
+ chat_history.append((message, response))
40
+
41
+ return "", chat_history, chat_history_ids
42
+
43
+ with gr.Blocks() as demo:
44
+ # Store model's conversation history
45
+ state = gr.State()
46
+
47
+ gr.Markdown("## DialoGPT Chatbot")
48
+ chatbot = gr.Chatbot()
49
+ msg = gr.Textbox(label="Your Message")
50
+ clear = gr.Button("Clear History")
51
+
52
+ msg.submit(
53
+ respond,
54
+ [msg, chatbot, state],
55
+ [msg, chatbot, state]
56
+ )
57
+ clear.click(lambda: (None, None), outputs=[chatbot, state], queue=False)
58
+
59
+ demo.launch()