tadaGoel committed on
Commit
3f5e2e0
·
verified ·
1 Parent(s): 59b8bf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -42
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Small conversational checkpoint that runs acceptably on free CPU Spaces.
MODEL_NAME = "microsoft/DialoGPT-small"

# Loaded once at import time and reused for every chat request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
@@ -39,53 +39,59 @@ Rules:
39
 
40
 
41
def build_prompt(message, history):
    """Render the chat so far as one plain-text prompt for DialoGPT.

    `history` is a list of (user_msg, bot_msg) pairs; empty/None halves
    of a pair are skipped. The prompt ends with an open "Shinchan:" line
    for the model to complete.
    """
    parts = [f"System: {SHINCHAN_SYSTEM_PROMPT}"]
    for user_msg, bot_msg in history:
        parts.extend(
            line
            for line in (
                user_msg and f"User: {user_msg}",
                bot_msg and f"Shinchan: {bot_msg}",
            )
            if line
        )
    parts.extend((f"User: {message}", "Shinchan:"))
    return "\n".join(parts)
 
 
 
 
 
 
52
 
53
 
54
def respond(message, history):
    """gr.ChatInterface callback.

    Args:
        message: the latest user message.
        history: prior (user_msg, bot_msg) pairs supplied by ChatInterface.

    Returns:
        The bot reply as a plain string. ChatInterface appends the reply to
        its own history, so the previous version's `return reply, history`
        was a bug: the framework then treated the tuple as a message object.
    """
    prompt = build_prompt(message, history or [])
    inputs = tokenizer(prompt, return_tensors="pt")

    # Keep generation small/light for CPU Spaces.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=80,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 family has no pad token
            do_sample=True,
            top_p=0.9,
            temperature=0.9,
        )

    full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Keep only the completion after the final "Shinchan:" marker
    # (the prompt itself ends with "Shinchan:").
    if "Shinchan:" in full_text:
        reply = full_text.split("Shinchan:")[-1].strip()
    else:
        reply = full_text.strip()

    return reply
82
 
83
 
84
# Chat UI wired to `respond`; this version uses Gradio's default
# tuple-pair history format.
demo = gr.ChatInterface(
    respond,
    title="Shinchan for Ruru",
    description="Private Shinchan-style chat for Ruru.",
)

if __name__ == "__main__":
    demo.launch()
 
2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Small conversational checkpoint; CPU-friendly for free Spaces hardware.
MODEL_NAME = "microsoft/DialoGPT-small"

# Loaded once at import time and reused for every chat request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
 
39
 
40
 
41
def build_prompt(message, history):
    """Flatten a messages-style history into one text prompt for DialoGPT.

    gr.ChatInterface(type="messages") supplies `history` as a list of
    {'role': 'user'|'assistant', 'content': '...'} dicts. Turns with any
    other role are ignored. The prompt ends with an open "Shinchan:" line
    for the model to complete.
    """
    role_labels = {"user": "User", "assistant": "Shinchan"}
    lines = [f"System: {SHINCHAN_SYSTEM_PROMPT}"]
    for turn in history:
        label = role_labels.get(turn.get("role"))
        if label is not None:
            lines.append(f"{label}: {turn.get('content', '')}")
    lines.append(f"User: {message}")
    lines.append("Shinchan:")
    return "\n".join(lines)
58
 
59
 
60
def respond(message, history):
    """gr.ChatInterface (type='messages') callback.

    Args:
        message: the latest user message.
        history: list of {'role', 'content'} dicts managed by ChatInterface.

    Returns:
        The reply string only. ChatInterface maintains the history itself;
        returning (reply, history) makes it treat the tuple as a message
        object ("'tuple' object has no attribute 'get'").
    """
    prompt = build_prompt(message, history)
    inputs = tokenizer(prompt, return_tensors="pt")
    # NOTE(review): very long chats can exceed DialoGPT's 1024-token context
    # window; consider trimming old turns in build_prompt if that happens.

    # Keep generation small/light for CPU Spaces.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=80,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 family has no pad token
            do_sample=True,
            top_p=0.9,
            temperature=0.9,
        )

    # Decode only the newly generated tokens. Slicing off the prompt is
    # robust even when the generated text itself contains "Shinchan:",
    # which made splitting the full decode on that marker fragile.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # If the model starts hallucinating the next user turn, keep only
    # Shinchan's line.
    reply = reply.split("User:")[0].strip()

    return reply
 
87
 
88
 
89
# Chat UI wired to `respond`; messages-style history matches what
# build_prompt/respond expect.
demo = gr.ChatInterface(
    fn=respond,
    type="messages",  # explicit; default since Gradio 5, but makes intent clear
    title="Shinchan for Ruru",
    description="Private Shinchan-style chat for Ruru.",
)

if __name__ == "__main__":
    demo.launch()