hamza2923 committed on
Commit
6fe40f1
·
verified ·
1 Parent(s): ddf305b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -37
app.py CHANGED
@@ -5,55 +5,71 @@ import torch
5
# Load model and tokenizer
# DialoGPT-small: a 117M-parameter conversational GPT-2 variant from Microsoft.
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
10
def respond(message, chat_history, chat_history_ids):
    """Generate one DialoGPT reply to *message* and extend the conversation.

    Parameters:
        message: raw user text from the Gradio textbox.
        chat_history: list of (user, bot) tuples shown in the Chatbot widget.
        chat_history_ids: token-id tensor of the full conversation so far,
            or None on the first turn.

    Returns:
        ("", chat_history, chat_history_ids) — empty string clears the
        textbox; the other two feed the Chatbot widget and gr.State.
    """
    # Encode user input
    # DialoGPT convention: every turn is terminated by the EOS token.
    new_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")

    # Append to chat history
    if chat_history_ids is not None:
        input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
    else:
        input_ids = new_input_ids

    # Generate response
    # NOTE(review): max_length bounds the TOTAL sequence (context + reply),
    # so very long conversations will eventually leave no room to generate.
    chat_history_ids = model.generate(
        input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.8
    )

    # Decode response
    # Slice off the prompt tokens so only the newly generated reply is decoded.
    response = tokenizer.decode(
        chat_history_ids[:, input_ids.shape[-1]:][0],
        skip_special_tokens=True
    )

    # Update conversation history
    chat_history.append((message, response))

    return "", chat_history, chat_history_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
with gr.Blocks() as demo:
    # Store model's conversation history (token ids) across submits.
    state = gr.State()

    gr.Markdown("## DialoGPT Chatbot")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear History")

    # Submit wires (textbox, chat display, hidden state) through respond
    # and back; respond's first return value ("") clears the textbox.
    msg.submit(
        respond,
        [msg, chatbot, state],
        [msg, chatbot, state]
    )
    # Clearing resets both the visible chat and the token-id state.
    clear.click(lambda: (None, None), outputs=[chatbot, state], queue=False)

# NOTE(review): reconstructed from a diff — launch() placement assumed to be
# at module top level, per the usual Gradio convention.
demo.launch()
 
 
 
 
5
# Pick the execution device up front: GPU when available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the DialoGPT-small conversational checkpoint and its tokenizer.
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# DialoGPT ships without a dedicated pad token; reuse EOS so padded
# generation has a valid pad id.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
15
def respond(message, chat_history, chat_history_ids):
    """Generate one DialoGPT reply and update the Gradio UI state.

    Parameters:
        message: raw user text from the textbox.
        chat_history: list of (user, bot) tuples for the Chatbot widget,
            or None on the first turn.
        chat_history_ids: token-id tensor of the conversation so far,
            or None on the first turn.

    Returns:
        (textbox_value, chat_history, chat_history_ids, error) — textbox is
        cleared on success; *error* is None on success or a message string
        on failure / empty input.
    """
    if not message.strip():
        return "", chat_history or [], chat_history_ids, "Please enter a message."

    if chat_history is None:
        chat_history = []

    # DialoGPT convention: every turn is terminated by the EOS token.
    new_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt").to(device)

    input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1) if chat_history_ids is not None else new_input_ids

    # Broad except is deliberate: this is the UI boundary, and any failure is
    # surfaced to the user via the error output instead of crashing the app.
    try:
        # Inference only — no autograd graph needed; saves memory per request.
        with torch.no_grad():
            chat_history_ids = model.generate(
                input_ids,
                # max_new_tokens bounds only the reply. The previous
                # max_length=200 bounded context + reply, so generation
                # starved once the conversation neared 200 tokens.
                max_new_tokens=200,
                pad_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8
            )

        # Decode only the newly generated tokens (skip the prompt).
        response = tokenizer.decode(
            chat_history_ids[:, input_ids.shape[-1]:][0],
            skip_special_tokens=True
        )

        chat_history.append((message, response))

        # Keep the context bounded: drop to the last 10 exchanges and
        # re-encode them as the new model history.
        if len(chat_history) > 10:
            chat_history = chat_history[-10:]
            # Each turn (user AND bot) must end with EOS; joining
            # msg + resp without a separator would corrupt the context.
            history_text = "".join(
                user_msg + tokenizer.eos_token + bot_resp + tokenizer.eos_token
                for user_msg, bot_resp in chat_history
            )
            chat_history_ids = tokenizer.encode(history_text, return_tensors="pt").to(device)

        return "", chat_history, chat_history_ids, None
    except Exception as e:
        return "", chat_history, chat_history_ids, f"Error: {str(e)}"
53
+
54
def clear_history():
    """Reset the chat display, the model's token history, and the error box."""
    empty_chat, no_ids, no_error = [], None, None
    return empty_chat, no_ids, no_error
56
 
57
with gr.Blocks() as demo:
    # Holds the model-side conversation (token ids) between submits.
    state = gr.State()
    gr.Markdown("## DialoGPT Chatbot")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.Button("Clear History")
    # visible=True: respond() routes its error strings here; with the
    # previous visible=False (and nothing toggling it) users never saw them.
    error = gr.Textbox(label="Error", interactive=False, visible=True)

    # respond returns (textbox, chat, state, error); the leading "" clears
    # the textbox after each submit.
    msg.submit(
        respond,
        inputs=[msg, chatbot, state],
        outputs=[msg, chatbot, state, error]
    )
    clear.click(
        fn=clear_history,
        inputs=None,
        outputs=[chatbot, state, error],
        queue=False
    )

# This commit dropped the original demo.launch(), so the app was built but
# never served; restore it (guarded so importing the module stays side-effect free).
if __name__ == "__main__":
    demo.launch()