File size: 7,037 Bytes
3e79950
b2c8e1d
2170185
 
 
fbd0e58
35eb385
b4cdfab
3e79950
a18d57c
c03a84c
a18d57c
3e79950
 
35eb385
2170185
35eb385
 
 
 
 
 
 
 
 
 
 
 
b4cdfab
116ad34
35eb385
116ad34
35eb385
 
 
 
 
92c739f
35eb385
 
92c739f
fa16b47
92c739f
35eb385
 
2e57be7
35eb385
 
92c739f
35eb385
 
 
 
 
 
 
 
 
92c739f
35eb385
 
 
 
fa16b47
35eb385
fa16b47
 
 
 
 
 
 
 
 
 
35eb385
 
92c739f
2170185
35eb385
2170185
35eb385
fa16b47
 
 
 
 
 
 
 
 
 
35eb385
fa16b47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35eb385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa16b47
 
35eb385
fa16b47
35eb385
92c739f
fa16b47
 
 
 
 
 
 
bd690cf
a18d57c
35eb385
a18d57c
f008056
fa16b47
19601ea
35eb385
19601ea
 
35eb385
 
 
 
 
b2c8e1d
35eb385
 
 
 
19601ea
35eb385
fa16b47
35eb385
b4686d0
 
 
 
 
 
 
19601ea
fa16b47
35eb385
 
 
fa16b47
d89e92d
b4686d0
 
 
19601ea
35eb385
 
 
 
 
fa16b47
 
2170185
f008056
35eb385
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import json
import os
import re
import csv
import tempfile
from rapidfuzz import fuzz
import datetime
import gradio as gr

# -----------------------------
# Config / data loading
# -----------------------------
DATA_PATH = "quotes.json"

def load_quotes():
    """Read the quotes dataset from DATA_PATH.

    Returns the parsed ``{category: entries}`` dict, or an empty dict when
    the file is missing, unreadable, or its JSON root is not an object.
    """
    if not os.path.exists(DATA_PATH):
        print("No dataset file found. Upload one via the UI.")
        return {}
    try:
        with open(DATA_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        print(f"Failed to load {DATA_PATH}: {e}")
        print("No dataset file found. Upload one via the UI.")
        return {}
    if not isinstance(data, dict):
        # Root must be an object mapping category -> entries.
        print("No dataset file found. Upload one via the UI.")
        return {}
    print(f"Loaded dataset from {DATA_PATH} with {len(data)} categories.")
    return data

QUOTES = load_quotes()

# -----------------------------
# Text helpers
# -----------------------------
# Words dropped during tokenization (carry little topical meaning).
STOPWORDS = {
    "the","a","an","and","or","but","if","then","so","than","to","of","in","on","at","for",
    "is","are","was","were","be","being","been","it","that","this","these","those","with",
    "as","by","from","about","into","over","after","before","up","down","out"
}

# Cue words for the rough positive/negative sentiment heuristic.
POS_HINTS = {"good","great","love","like","enjoy","awesome","amazing","nice","positive","best","fantastic","excellent"}
NEG_HINTS = {"bad","hate","dislike","worst","awful","terrible","negative","poor","meh","gross","unsafe","hard","difficult"}

# Matches any single ASCII punctuation character; used to blank punctuation out.
punct_re = re.compile(r"[{}]".format(re.escape("""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~""")))

def normalize(text: str) -> str:
    """Lowercase *text* and replace every ASCII punctuation char with a space."""
    return punct_re.sub(" ", (text or "").lower())

def tokenize(text: str):
    """Split normalized *text* into words, dropping STOPWORDS."""
    return [t for t in normalize(text).split() if t and t not in STOPWORDS]

def infer_sentiment(user_text: str) -> str:
    """Classify *user_text* as "positive" or "negative" from hint words.

    Matches whole words (set membership) rather than substrings, so that
    e.g. "dislike" is no longer mistaken for "like" and "hardware" is not
    mistaken for "hard". Ambiguous or neutral text defaults to "positive".
    """
    words = set(normalize(user_text).split())
    has_pos = bool(words & POS_HINTS)
    has_neg = bool(words & NEG_HINTS)
    if has_pos and not has_neg:
        return "positive"
    if has_neg and not has_pos:
        return "negative"
    return "positive"

# -----------------------------
# Retrieval
# -----------------------------
def best_match_quote(user_text: str, category=None) -> str:
    """Return the dataset quote that best fuzzy-matches *user_text*.

    Searches every category unless *category* names one present in QUOTES,
    in which case only that category is searched (default keeps the old
    search-everything behavior). Entries may be dicts with a "quote" key
    or plain strings; malformed entries/categories are skipped instead of
    crashing. Returns a "No data" message when nothing scores at least 30
    on rapidfuzz's partial_ratio.
    """
    query = (user_text or "").lower()  # hoisted: was lowered once per quote
    best_score = 0
    best_quote = None
    if category and category in QUOTES:
        search_space = {category: QUOTES[category]}
    else:
        search_space = QUOTES
    for cat, quotes_list in search_space.items():
        if not isinstance(quotes_list, list):
            continue  # malformed category value — skip rather than crash
        for quote_entry in quotes_list:
            if isinstance(quote_entry, dict):
                quote = quote_entry.get("quote", "")
            elif isinstance(quote_entry, str):
                quote = quote_entry
            else:
                continue  # unsupported entry shape
            if not quote.strip():
                continue
            score = fuzz.partial_ratio(query, quote.lower())
            if score > best_score:
                best_score = score
                best_quote = quote
    if best_score < 30 or best_quote is None:
        return f"No data about '{user_text}'"
    return best_quote

# -----------------------------
# Gradio callbacks
# -----------------------------
def _first_sentence(text: str) -> str:
    """Return the first sentence of *text*, up to and including the first
    period; the whole text when it contains no period."""
    head, sep, _ = text.partition(".")
    return head + sep if sep else text

def respond(message, history, category):
    """Chat callback: append the user turn and a bot reply to *history*.

    Returns ("", history) so Gradio clears the textbox and refreshes the
    chatbot. History uses the "messages" format ({"role", "content"} dicts).
    """
    history = history or []  # tolerate an unset chat state

    def _reply(bot_text):
        # Record both sides of the turn and clear the input box.
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": bot_text})
        return "", history

    if not QUOTES:
        return _reply("No dataset loaded. Please upload a JSON file first.")
    if not category:
        return _reply("Please select a category.")

    quote = best_match_quote(message)

    # A retrieval miss is reported once, not dressed up as summary + quote
    # (the old code echoed the "No data" message three times).
    if quote.startswith("No data about"):
        return _reply(quote)

    # 2-part response: short summary line plus the full quote. Using
    # partition-based extraction avoids the old bug where a quote with a
    # trailing period but no ". " separator got a doubled period.
    summary = _first_sentence(quote)
    return _reply(f"Summary:\n{summary}\n\nWhat real people say:\n{quote}")

def clear_chat():
    """Reset the chatbot display; returning None clears the component."""
    return None

def upload_json(filepath):
    """Replace the in-memory dataset from an uploaded JSON file.

    Returns (status_message, dropdown_update). On failure the current
    dataset is kept and gr.update() leaves the category dropdown's
    existing choices untouched (the old code wiped them with choices=[]).
    """
    global QUOTES, DATA_PATH
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        return f"Error loading file: {e}", gr.update()
    if not isinstance(data, dict):
        return "Upload failed: JSON root must be an object.", gr.update()
    QUOTES = data
    # NOTE(review): only the basename is kept, so a later load_quotes()
    # would not find this file again — confirm whether that is intended.
    DATA_PATH = os.path.basename(filepath)
    cats = sorted(QUOTES.keys())
    status = f"Loaded {len(cats)} categories from {DATA_PATH}."
    return status, gr.update(choices=cats, value=(cats[0] if cats else None))

def download_current_json():
    """Serialize the in-memory dataset to a temp file and return its path."""
    fd, path = tempfile.mkstemp(suffix=".json")
    # os.fdopen takes ownership of fd and the with-block closes it; the old
    # NamedTemporaryFile handle was never closed (fd leak, and re-opening
    # the file by name fails on Windows while the handle is still open).
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(QUOTES, f, indent=2, ensure_ascii=False)
    return path

def download_conversation_csv(history):
    """Write the chat *history* (list of {"role", "content"} dicts in Gradio
    "messages" format) to a temp CSV file and return its path."""
    fd, path = tempfile.mkstemp(suffix=".csv")
    # os.fdopen + with closes the descriptor; the old NamedTemporaryFile
    # handle was leaked and kept the file open alongside the second handle.
    with os.fdopen(fd, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["role", "message"])
        for msg in history or []:  # tolerate an empty/unset chat
            writer.writerow([msg.get("role"), msg.get("content")])
    return path

# -----------------------------
# UI — Gradio Blocks layout and event wiring
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🎓 College Life Chatbot — Category-Aware, Fuzzy Matching")

    # Categories known at startup; upload_json refreshes this dropdown later.
    initial_categories = sorted(list(QUOTES.keys()))

    with gr.Row():
        category = gr.Dropdown(
            label="Category",
            choices=initial_categories,
            value=(initial_categories[0] if initial_categories else None)
        )

    # type="messages" => history is a list of {"role", "content"} dicts,
    # the shape respond() and download_conversation_csv() consume.
    chatbot = gr.Chatbot(label="Conversation", height=360, type="messages")
    msg = gr.Textbox(label="Your message", placeholder="Ask something like: 'Is food good in college?'", autofocus=True)
    send = gr.Button("Send")
    clear = gr.Button("Clear")

    with gr.Row():
        upload_btn = gr.File(label="Upload dataset (.json)", file_types=[".json"], type="filepath")
        upload_status = gr.Textbox(label="Upload status", interactive=False)

    # Download controls: the buttons generate temp files which the File
    # components below expose as download links.
    with gr.Row():
        download_json_btn = gr.Button("Download current dataset (.json)")
        download_csv_btn = gr.Button("Export conversation to CSV")
        download_json_file = gr.File(label="JSON download")
        download_csv_file = gr.File(label="CSV download")

    # Event wiring. Enter in the textbox and the Send button share respond().
    # NOTE(review): category is passed to respond() but retrieval searches
    # all categories — confirm whether per-category filtering was intended.
    msg.submit(respond, [msg, chatbot, category], [msg, chatbot])
    send.click(respond, [msg, chatbot, category], [msg, chatbot])
    clear.click(clear_chat, None, chatbot, queue=False)
    upload_btn.upload(upload_json, upload_btn, [upload_status, category])

    # Button -> File wiring: each callback returns a filepath, which the
    # File component renders as a downloadable file.
    download_json_btn.click(fn=download_current_json, inputs=None, outputs=download_json_file)
    download_csv_btn.click(fn=download_conversation_csv, inputs=chatbot, outputs=download_csv_file)

# -----------------------------
# Startup log
# -----------------------------
print(f"===== Application Startup at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
# Summarize the loaded dataset so an empty or malformed file is visible in logs.
if QUOTES:
    for cat, entries in QUOTES.items():
        print(f" - {cat}: {len(entries)} entries")

# server_name="0.0.0.0" listens on all interfaces (container/Spaces friendly);
# 7860 is the conventional Gradio port.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)