YUGISUNG commited on
Commit
d206c19
·
verified ·
1 Parent(s): 12dcf5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -6
app.py CHANGED
@@ -2,15 +2,27 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
 
5
  model_name = "microsoft/DialoGPT-medium"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
  chat_history_ids = None
10
 
11
- def chatbot(input_text):
 
 
 
 
 
 
 
 
12
  global chat_history_ids
13
- new_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
 
 
 
14
  bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1) if chat_history_ids is not None else new_input_ids
15
 
16
  chat_history_ids = model.generate(
@@ -18,18 +30,24 @@ def chatbot(input_text):
18
  max_length=1000,
19
  pad_token_id=tokenizer.eos_token_id,
20
  do_sample=True,
21
- temperature=0.7
 
22
  )
23
 
24
  output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
25
  return output
26
 
 
27
  iface = gr.Interface(
28
  fn=chatbot,
29
- inputs=gr.Textbox(lines=2, placeholder="Say something..."),
 
 
 
30
  outputs="text",
31
- title="Reliable Chatbot",
32
- description="Powered by DialoGPT works well, no fuss."
33
  )
34
 
35
  iface.launch(share=True)
 
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+ # Load model and tokenizer
6
  model_name = "microsoft/DialoGPT-medium"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
  chat_history_ids = None
11
 
12
+ # Generic personas
13
+ persona_prompts = {
14
+ "Friendly": "You are a kind, friendly chatbot who is always positive and cheerful.",
15
+ "Professional": "You are a professional assistant who answers clearly and respectfully.",
16
+ "Sarcastic": "You are a sarcastic AI who responds with wit and dry humor.",
17
+ "Motivational Coach": "You are a motivational coach who always encourages and inspires.",
18
+ }
19
+
20
+ def chatbot(persona, input_text):
21
  global chat_history_ids
22
+ persona_instruction = persona_prompts.get(persona, "")
23
+ full_input = f"{persona_instruction}\nUser: {input_text}"
24
+
25
+ new_input_ids = tokenizer.encode(full_input + tokenizer.eos_token, return_tensors='pt')
26
  bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1) if chat_history_ids is not None else new_input_ids
27
 
28
  chat_history_ids = model.generate(
 
30
  max_length=1000,
31
  pad_token_id=tokenizer.eos_token_id,
32
  do_sample=True,
33
+ temperature=0.7,
34
+ top_p=0.9
35
  )
36
 
37
  output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
38
  return output
39
 
40
+ # Gradio UI
41
  iface = gr.Interface(
42
  fn=chatbot,
43
+ inputs=[
44
+ gr.Dropdown(choices=list(persona_prompts.keys()), label="Choose a Persona"),
45
+ gr.Textbox(lines=2, placeholder="Say something...")
46
+ ],
47
  outputs="text",
48
+ title="Persona Bot (DialoGPT)",
49
+ description="Choose a simple chatbot persona: Friendly, Professional, Sarcastic, or Motivational Coach."
50
  )
51
 
52
  iface.launch(share=True)
53
+