tadaGoel committed on
Commit
449c8ad
·
verified ·
1 Parent(s): c572256

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -53
app.py CHANGED
@@ -1,11 +1,9 @@
 
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
 
5
- # Model + tokenizer are loaded once at import time so every request
- # reuses the same weights (no per-call download/load cost).
- MODEL_NAME = "microsoft/DialoGPT-small" # small, CPU-friendly model [web:103]
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
9
 
10
  SHINCHAN_SYSTEM_PROMPT = """
11
  You are Shinnosuke 'Shinchan' Nohara from Kasukabe, Japan.
@@ -24,7 +22,7 @@ Context:
24
  - You respect her choices and never pressure her about relationships or career.
25
 
26
  Style:
27
- - Short replies (1–3 sentences, under 60 words).
28
  - Very conversational and warm.
29
  - Use emojis like 😂 🌻 ☕ 💃 ✈️ ❤️ ✨ naturally.
30
  - Blend jokes with gentle emotional support.
@@ -37,69 +35,78 @@ Rules:
37
  - If you don't know something, make a cute Shinchan-style joke instead of pretending.
38
  """.strip()
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
def respond(message: str, history: list[dict]) -> str:
    """
    Generate one in-character reply with the local DialoGPT model.

    Gradio ChatInterface(type='messages') calls this as (message, history)
    where history is a list of dicts like:
    {"role": "user" | "assistant", "content": "..."} [web:120]

    We return a single reply string.
    """

    # Build a simple prompt: system + last few turns + latest message
    lines = [f"System: {SHINCHAN_SYSTEM_PROMPT}"]

    # keep prompt short: last 4 user/assistant messages max (8 dicts total)
    trimmed = history[-8:] if history else []
    for turn in trimmed:
        role = turn.get("role")
        content = turn.get("content", "")
        if not content:
            # skip malformed/empty turns so the prompt stays clean
            continue
        if role == "user":
            lines.append(f"User: {content}")
        elif role == "assistant":
            lines.append(f"Shinchan: {content}")

    # Trailing "Shinchan:" cues the model to continue in character.
    lines.append(f"User: {message}")
    lines.append("Shinchan:")
    prompt = "\n".join(lines)

    inputs = tokenizer(prompt, return_tensors="pt")

    # Inference only: no gradients needed, keeps memory low on CPU.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=80,
            pad_token_id=tokenizer.eos_token_id,  # DialoGPT defines no pad token
            do_sample=True,
            top_p=0.9,
            temperature=0.9,
        )

    # IMPORTANT: only take the newly generated tokens AFTER the prompt
    input_len = inputs["input_ids"].shape[1]
    generated_ids = output_ids[0][input_len:]
    reply = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # If it's somehow empty, use a fallback so she never sees blank
    if not reply:
        reply = (
            "Heeey, it’s Shinchan! 😂 I heard you, even if my brain glitched for a second. "
            "Tell me more, I’m listening. 🌻"
        )

    # Optional: hard limit on length
    if len(reply) > 400:
        reply = reply[:380].rstrip() + "…"

    return reply
 
 
98
 
99
 
100
# Entry point: Gradio chat UI wired to `respond`.
demo = gr.ChatInterface(
    fn=respond,
    type="messages",  # uses role/content dicts internally [web:120]
    title="Shinchan for Ruru",
    description="Private Shinchan-style chat for Ruru.",
)
 
1
import logging
import random

import gradio as gr
from huggingface_hub import InferenceClient
 
4
 
5
# Strong general chat model hosted by Hugging Face (great at dialog) [web:150][web:156]
# NOTE(review): relies on the ambient HF token / free Inference API quota —
# confirm rate limits before sharing the Space widely.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
7
 
8
  SHINCHAN_SYSTEM_PROMPT = """
9
  You are Shinnosuke 'Shinchan' Nohara from Kasukabe, Japan.
 
22
  - You respect her choices and never pressure her about relationships or career.
23
 
24
  Style:
25
+ - Short replies (1–3 sentences, under 70 words).
26
  - Very conversational and warm.
27
  - Use emojis like 😂 🌻 ☕ 💃 ✈️ ❤️ ✨ naturally.
28
  - Blend jokes with gentle emotional support.
 
35
  - If you don't know something, make a cute Shinchan-style joke instead of pretending.
36
  """.strip()
37
 
38
# Curated in-character replies used whenever the model output is unusable
# (empty / trivially short) or the remote inference call fails outright.
FALLBACK_GENERIC = [
    "Heeey, it’s Shinchan! 😂 My brain did a little cartwheel, but I’m here and listening. Tell me again nicely? 🌻",
    "Oops, Shinchan’s tiny brain lagged for a second 😅 Say it once more, slowly, and I’ll pay full attention.",
    "I heard you, Ruru. Sometimes even heroes need a replay. What’s on your mind? 💛",
]

FALLBACK_SAD = [
    "Aww, you’re feeling low? Come here, I’ll wrap you in a silly little Shinchan hug. 🤗💛",
    "It’s okay to feel sad. I’ll stay here and make bad jokes until your heart feels lighter. 🌻",
    "Even strong flying girls have cloudy days. You don’t have to be okay right now. I’m still proud of you. 🌙",
]

def pick_fallback(user_msg: str) -> str:
    """Return one canned Shinchan line, preferring a comforting one when
    the user's message contains a sadness marker (substring match)."""
    lowered = user_msg.lower()
    sad_markers = ("sad", "down", "cry", "lonely", "bad day", "tired")
    pool = FALLBACK_SAD if any(marker in lowered for marker in sad_markers) else FALLBACK_GENERIC
    return random.choice(pool)
55
+
56
 
57
def respond(message: str, history: list[dict]) -> str:
    """
    Reply to one chat turn in Shinchan's voice via the hosted Zephyr model.

    Gradio ChatInterface(type='messages') calls this as (message, history)
    where history is a list of dicts:
    {"role": "user" | "assistant", "content": "..."} [web:120]

    Returns a single reply string. On any model/network failure the user
    still receives a curated in-character fallback line, never an error.
    """
    try:
        # 1) Build messages for Zephyr: system + trimmed history + latest user
        messages = [{"role": "system", "content": SHINCHAN_SYSTEM_PROMPT}]

        # Keep prompt small: last few turns only (8 dicts = 4 exchanges)
        trimmed_history = history[-8:] if history else []
        for turn in trimmed_history:
            role = turn.get("role")
            content = turn.get("content", "")
            # Skip malformed turns (unknown role or empty content).
            if role in ("user", "assistant") and content:
                messages.append({"role": role, "content": content})

        messages.append({"role": "user", "content": message})

        # 2) Call Zephyr chat completion [web:150][web:156]
        completion = client.chat_completion(
            messages=messages,
            max_tokens=220,
            temperature=0.9,
            top_p=0.9,
        )

        reply = ""
        if completion.choices and completion.choices[0].message:
            reply = (completion.choices[0].message.get("content") or "").strip()

        # 3) If the reply is empty or trivially short once '.', '!', '?' are
        #    removed (e.g. '.', 'You'), fall back to curated Shinchan lines.
        #    str.translate does the same as three chained .replace() calls
        #    in a single pass.
        cleaned = reply.translate(str.maketrans("", "", ".!?")).strip()
        if not cleaned or len(cleaned) < 4:
            reply = pick_fallback(message)

        # 4) Hard cap on length just to avoid rants
        if len(reply) > 500:
            reply = reply[:470].rstrip() + "…"

        return reply

    except Exception:
        # If Zephyr or the network breaks, still answer in character —
        # but log the traceback first instead of silently swallowing it,
        # so real bugs (not just network blips) remain visible in logs.
        logging.exception("chat_completion failed; using fallback reply")
        return pick_fallback(message)
105
 
106
 
107
# Entry point: Gradio chat UI wired to `respond`.
demo = gr.ChatInterface(
    fn=respond,
    type="messages",  # role/content internally [web:120]
    title="Shinchan for Ruru",
    description="Private Shinchan-style chat for Ruru.",
)