File size: 6,786 Bytes
d511a4d
971d379
064d610
bc9d6cb
bafbae3
0e21d39
d511a4d
91fd18e
 
8d31829
91fd18e
 
d511a4d
91fd18e
9779c60
91fd18e
d8f2828
 
 
 
 
 
d511a4d
bafbae3
8d31829
 
91fd18e
1179ade
 
 
9c4ec4e
 
 
1179ade
 
bafbae3
 
 
 
1179ade
8d31829
e645da6
8d31829
9c4ec4e
7f620c5
 
 
9c4ec4e
 
 
 
 
 
 
bafbae3
d511a4d
bafbae3
 
9779c60
d511a4d
 
8d31829
0e21d39
d511a4d
9779c60
bafbae3
 
 
d8f2828
 
 
 
 
 
 
 
bafbae3
 
971d379
bafbae3
 
 
8d31829
 
bafbae3
 
971d379
 
8d31829
 
971d379
 
 
 
 
 
 
 
 
8d31829
 
 
971d379
 
 
8d31829
 
971d379
 
 
8d31829
 
 
bafbae3
 
8d31829
d511a4d
9779c60
 
8d31829
 
9c4ec4e
91fd18e
 
9779c60
 
bafbae3
064d610
8d31829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bafbae3
8d31829
064d610
bafbae3
8d31829
 
 
e645da6
 
e4c20b9
 
e645da6
 
 
f04be6a
 
8d31829
f04be6a
 
 
 
 
 
 
 
8d31829
 
 
f04be6a
8d31829
 
d511a4d
f04be6a
ec284b1
064d610
d8f2828
7f620c5
 
d8f2828
7f620c5
8d31829
d8f2828
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import re
import gradio as gr

# --------- CPU hygiene (nice-to-have) ----------
# Silence the HF tokenizers fork-parallelism warning; harmless if unused.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Best-effort: cap torch intra-op CPU threads so a small Space stays
# responsive. torch may be absent entirely, hence the outer try/except.
try:
    import torch
    try:
        torch.set_num_threads(2)
    except Exception:
        pass
except Exception:
    pass

from transformers import pipeline

# vaderSentiment (graceful optional import)
# If the package is missing, get_vader() below substitutes a neutral stub.
try:
    from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
except Exception:
    SentimentIntensityAnalyzer = None

# -------- Model / bot configuration --------
# Small instruction-tuned T5 variant (CPU-friendly), loaded lazily by get_t2t().
GEN_MODEL_NAME = "MBZUAI/LaMini-Flan-T5-248M"
# Optional banner image; the UI only renders it if this file exists.
WELCOME_IMAGE_PATH = "assets/cat.png"

# System-style instructions prepended to every LLM fallback prompt; they keep
# the model on-topic and suppress boilerplate that strip_preamble() also scrubs.
DOMAIN_INSTRUCTIONS = (
    "You are a concise assistant about cats in ancient Egypt. "
    "Keep focus on Bastet, cat mummies, daily life, worship, and other ancient Egypt facts. "
    "If the user asks something unrelated, say briefly that you only cover those topics and suggest one. "
    "Do not include greetings or apologies and do not say phrases like 'as an AI language model'. "
    "Start directly with the answer."
)

# Topic menu shown on greeting/help requests.
HELP_TEXT = (
    "Ask me about: Bastet • cat mummies • daily life • worship\n"
    "Type anything else to try the AI fallback."
)

# Initial assistant message seeded into the chat history below.
WELCOME_TEXT = "Hi! I share facts about cats in ancient Egypt.\n\n" + HELP_TEXT

# -------- Output cleanup --------
# Regexes (applied in order, case-insensitively) that remove boilerplate
# preambles the model sometimes emits despite DOMAIN_INSTRUCTIONS.
DISCLAIMER_PATTERNS = [
    r"^\s*(hi|hello|hey)[,!.?\s-]*",
    r"^\s*i'?m\s+sorr(y|ied)[^.\n]*[.\n]*",
    r"^\s*as an ai language model[^.\n]*[.\n]*"
]
# Compile once at import time instead of handing raw strings to re.sub on
# every call (hoists the pattern lookup/compile out of the per-reply path).
_DISCLAIMER_RES = [re.compile(p, re.IGNORECASE) for p in DISCLAIMER_PATTERNS]

def strip_preamble(text: str) -> str:
    """Strip greeting/apology/AI-disclaimer preambles from *text*.

    Patterns are anchored at the start and applied sequentially, so a
    greeting followed by an apology is removed in one pass. Returns the
    stripped remainder ("" for None/empty input).
    """
    t = text or ""
    for rx in _DISCLAIMER_RES:
        t = rx.sub("", t)
    return t.strip()

# -------- Lazy singletons --------
# Created on first use so import stays fast and the model only loads on demand.
_t2t = None
_vader = None

def get_t2t():
    """Return the shared text2text-generation pipeline, building it lazily."""
    global _t2t
    if _t2t is not None:
        return _t2t
    _t2t = pipeline(
        "text2text-generation",
        model=GEN_MODEL_NAME,
        tokenizer=GEN_MODEL_NAME,
    )
    print(f"[startup] Loaded model: {GEN_MODEL_NAME}")
    return _t2t

def get_vader():
    """Return the shared sentiment analyzer, building it lazily.

    If vaderSentiment is not installed, substitutes a stub whose
    polarity_scores() always reports a neutral compound of 0.0.
    """
    global _vader
    if _vader is not None:
        return _vader
    if SentimentIntensityAnalyzer is not None:
        _vader = SentimentIntensityAnalyzer()
        print("[startup] Loaded VADER sentiment analyzer")
        return _vader
    print("[startup] VADER not installed; using neutral sentiment fallback")

    class _NeutralVader:
        # Minimal stand-in matching the one method callers use.
        def polarity_scores(self, _):
            return {"compound": 0.0}

    _vader = _NeutralVader()
    return _vader

# -------- Sentiment helpers --------
def detect_sentiment_bucket(text: str):
    """Classify *text* via VADER into ("neg"|"neu"|"pos", compound score)."""
    compound = get_vader().polarity_scores(text or "").get("compound", 0.0)
    if compound <= -0.4:
        bucket = "neg"
    elif compound >= 0.4:
        bucket = "pos"
    else:
        bucket = "neu"
    return bucket, compound

def is_question(text: str) -> bool:
    """Heuristically decide whether *text* reads as a question."""
    stripped = (text or "").strip()
    # A question mark anywhere is a strong-enough signal on its own.
    if "?" in stripped:
        return True
    # Otherwise look for an interrogative/auxiliary opener.
    opener = r"^(who|what|when|where|why|how|do|does|did|can|could|is|are|was|were|should|would|will)\b"
    return re.match(opener, stripped.lower()) is not None

def is_thanks_or_praise(text: str) -> bool:
    """Return True when *text* contains a gratitude or praise phrase."""
    lowered = (text or "").lower()
    keywords = (
        "thanks", "thank you", "appreciate", "appreciated",
        "great answer", "nice", "awesome", "love", "helpful",
        "that helped", "that was good", "i like your response",
    )
    for kw in keywords:
        if kw in lowered:
            return True
    return False

# Tone prefixes prepended to replies; choose_positive_prefix() always picks
# index 0, so the extra entries are currently unused alternatives.
POS_QUESTION_PREFIXES = ["Good question! ", "Nice one—here’s the gist: ", "Let’s dig in. "]
POS_PRAISE_PREFIXES   = ["You’re welcome—glad that helped. ", "Appreciate the kind words! ", "Happy it was useful. "]
POS_STATEMENT_PREFIXES= ["Sounds good. ", "Got it. ", "All right—here’s the short version. "]
# Fixed prefix used for negative-sentiment messages.
NEG_PREFIX = "Calm down. You're being a little too negative! "

def choose_positive_prefix(message: str) -> str:
    """Pick an upbeat prefix matching the kind of positive message."""
    if is_thanks_or_praise(message):
        pool = POS_PRAISE_PREFIXES
    elif is_question(message):
        pool = POS_QUESTION_PREFIXES
    else:
        pool = POS_STATEMENT_PREFIXES
    # Deterministic choice: always the first entry of the matching pool.
    return pool[0]

def apply_tone_prefix(reply_text: str, bucket: str, message: str = "") -> str:
    """Prepend a sentiment-appropriate prefix to *reply_text*.

    "pos" picks a context-aware upbeat prefix, "neg" uses the fixed
    de-escalation line, and any other bucket leaves the reply unprefixed.
    """
    prefix = ""
    if bucket == "neg":
        prefix = NEG_PREFIX
    elif bucket == "pos":
        prefix = choose_positive_prefix(message)
    return (prefix + (reply_text or "")).strip()

# ---- LLM fallback ----
def ai_fallback(prompt: str) -> str:
    """Answer an off-menu question with the LLM, constrained to the domain.

    Returns a cleaned generation, or a friendly error string if the
    pipeline fails for any reason (model load, OOM, bad input, ...).
    """
    try:
        full_prompt = f"{DOMAIN_INSTRUCTIONS}\n\nUser: {prompt}\nAssistant:"
        result = get_t2t()(
            full_prompt,
            max_new_tokens=48,
            do_sample=False,
            no_repeat_ngram_size=3,
        )
        return strip_preamble(result[0]["generated_text"])
    except Exception as e:
        print("AI fallback error:", repr(e))
        return "AI fallback had an issue. Please try a simpler question or use the topics in 'help'."

# -------- Chat logic --------
def reply(message, history):
    """Produce the bot's reply for *message* (*history* is unused).

    Routes to canned topic answers by keyword, falls back to the LLM for
    anything else, then wraps the result with a sentiment-based tone prefix.
    """
    bucket, _score = detect_sentiment_bucket(message or "")
    msg = (message or "").strip().lower()

    if re.search(r"\b(hi|hello|hey|hiya|yo|greetings)\b", msg) or any(
        k in msg for k in ["help", "menu", "topics", "instructions"]
    ):
        # Reuse the module-level constant instead of rebuilding the same
        # string by hand, so the greeting cannot drift from WELCOME_TEXT.
        base = WELCOME_TEXT
    elif "bastet" in msg or "bast" in msg:
        base = "Bastet (later cat-headed) … major cult center at Bubastis in the Nile Delta."
    elif any(w in msg for w in ["mummy", "mummies", "mummified", "offering"]):
        base = "Millions of animal mummies (cats common), esp. Late Period (664–332 BCE)."
    elif any(w in msg for w in ["daily", "life", "pest", "mouse", "rat", "snake"]):
        base = "Cats protected grain stores; art shows them under chairs/on leashes with owners."
    elif any(w in msg for w in ["worship", "god", "goddess", "taboo"]):
        base = "People didn’t worship pet cats as gods; they revered cats via Bastet and votive offerings."
    else:
        base = ai_fallback(message)

    return apply_tone_prefix(base, bucket, message)

# -------- UI --------
# Seed the chat history so the welcome message appears before any user turn.
initial_messages = [
    {"role": "assistant", "content": WELCOME_TEXT}  # text only in the chat history (stable)
]

# Pre-built chatbot widget handed to ChatInterface below; "messages" type
# matches the dict-based history format used above.
chatbot_component = gr.Chatbot(
    type="messages",
    value=initial_messages,
    show_label=False,
)

# UI (banner is collapsible so it won't push the input off-screen)
with gr.Blocks(fill_height=True) as demo:
    # Banner is optional: only rendered when the asset file is present.
    if os.path.exists(WELCOME_IMAGE_PATH):
        with gr.Accordion("Show banner", open=False):
            gr.Image(
                value=WELCOME_IMAGE_PATH,
                show_label=False,
                interactive=False,
                container=False,
                height=200
            )
    gr.ChatInterface(
        fn=reply,
        title="😺 Cats of Ancient Egypt Chatbot 😺",
        chatbot=chatbot_component,  # already defined above
        type="messages",
    )


# -------- Launch (SSR off + Spaces-friendly) --------
if __name__ == "__main__":
    demo.launch(
        ssr_mode=False,
        server_name="0.0.0.0",  # bind all interfaces (needed inside containers)
        server_port=int(os.getenv("PORT", "7860")),  # honor host-provided port
        share=False,
        allowed_paths=["assets"],  # let Gradio serve the banner image file
    )