# sandbox_dev / app.py
# igortech's picture
# Update app.py
# b4686d0 verified
import json
import os
import re
import csv
import tempfile
from rapidfuzz import fuzz
import datetime
import gradio as gr
# -----------------------------
# Config / data loading
# -----------------------------
# Dataset file read by load_quotes(); a relative path, so it is resolved
# against the current working directory. upload_json() reassigns it.
DATA_PATH = "quotes.json"
def load_quotes(path=None):
    """Load the quotes dataset from a JSON file.

    Args:
        path: File to read. When omitted, the module-level ``DATA_PATH``
            is used.

    Returns:
        The parsed dataset when the file can be read and its JSON root is
        an object (dict); otherwise an empty dict. Load problems are
        printed rather than raised, and each failure mode gets its own
        message so a broken file is not reported as a missing one.
    """
    if path is None:
        path = DATA_PATH
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print("No dataset file found. Upload one via the UI.")
        return {}
    except Exception as e:
        # Startup boundary: an unreadable or malformed file must not stop
        # the app from launching, so report it and fall back to no data.
        print(f"Failed to load {path}: {e}")
        return {}
    if not isinstance(data, dict):
        print(f"Failed to load {path}: JSON root must be an object.")
        return {}
    print(f"Loaded dataset from {path} with {len(data)} categories.")
    return data
# In-memory dataset used by the callbacks; loaded when the module is
# imported and replaced by upload_json().
QUOTES = load_quotes()
# -----------------------------
# Text helpers
# -----------------------------
# Words dropped by tokenize().
STOPWORDS = {
    "the", "a", "an", "and", "or", "but", "if", "then", "so", "than", "to",
    "of", "in", "on", "at", "for", "is", "are", "was", "were", "be", "being",
    "been", "it", "that", "this", "these", "those", "with", "as", "by",
    "from", "about", "into", "over", "after", "before", "up", "down", "out",
}
# Cue words read by infer_sentiment().
POS_HINTS = {
    "good", "great", "love", "like", "enjoy", "awesome", "amazing", "nice",
    "positive", "best", "fantastic", "excellent",
}
NEG_HINTS = {
    "bad", "hate", "dislike", "worst", "awful", "terrible", "negative",
    "poor", "meh", "gross", "unsafe", "hard", "difficult",
}
# Matches any single ASCII punctuation character. The character list is a
# raw string so that "\]" is not parsed as an (invalid) string escape, which
# emits a SyntaxWarning on Python 3.12+; the compiled pattern is unchanged.
punct_re = re.compile(
    r"[{}]".format(re.escape(r"""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"""))
)
def normalize(text: str) -> str:
    """Lower-case *text* and replace every ``punct_re`` match with a space.

    A falsy value (``None`` or ``""``) is treated as the empty string.
    """
    lowered = text.lower() if text else ""
    return punct_re.sub(" ", lowered)
def tokenize(text: str):
    """Return the normalized words of *text*, leaving out ``STOPWORDS``."""
    kept = []
    for word in normalize(text).split():
        # str.split() with no separator never yields empty strings, so
        # only the stopword check is needed here.
        if word not in STOPWORDS:
            kept.append(word)
    return kept
def infer_sentiment(user_text: str) -> str:
    """Classify *user_text* as ``"positive"`` or ``"negative"``.

    Hint words are compared against whole words of the normalized text
    rather than searched as substrings, so "dislike" is not counted as the
    positive hint "like" and "hardware" is not counted as the negative
    hint "hard".

    Returns:
        ``"negative"`` only when a negative hint is present and no positive
        hint is; ``"positive"`` in every other case (positive only, mixed,
        or no hints at all).
    """
    words = set(normalize(user_text).split())
    has_pos = not words.isdisjoint(POS_HINTS)
    has_neg = not words.isdisjoint(NEG_HINTS)
    if has_neg and not has_pos:
        return "negative"
    # Mixed and hint-free inputs default to "positive", like the pure
    # positive case.
    return "positive"
# -----------------------------
# Retrieval
# -----------------------------
def best_match_quote(user_text: str) -> str:
    """Return the stored quote that best fuzzy-matches *user_text*.

    Every category in the module-level ``QUOTES`` is searched, and quotes
    are compared case-insensitively with ``fuzz.partial_ratio``. Entries
    that are not usable (a category value that is not a list, an entry
    that is not a dict, or a missing, non-string or blank ``"quote"``) are
    skipped instead of raising, since the dataset may come from an
    arbitrary uploaded file.

    Args:
        user_text: The query; ``None`` is treated as an empty query.

    Returns:
        The text of the best-scoring quote, or ``"No data about
        '<user_text>'"`` when no quote scores at least 30.
    """
    # Loop-invariant: lower-case the query once, not once per quote.
    query = (user_text or "").lower()
    best_score = 0
    best_quote = None
    for quotes_list in QUOTES.values():
        if not isinstance(quotes_list, list):
            continue
        for entry in quotes_list:
            quote = entry.get("quote") if isinstance(entry, dict) else None
            if not isinstance(quote, str) or not quote.strip():
                continue
            score = fuzz.partial_ratio(query, quote.lower())
            # Strict ">" keeps the first quote seen among equal scores.
            if score > best_score:
                best_score, best_quote = score, quote
    if best_quote is None or best_score < 30:
        return f"No data about '{user_text}'"
    return best_quote
# -----------------------------
# Gradio callbacks
# -----------------------------
def _append_exchange(history, user_text, bot_text):
    """Append one user/assistant message pair to *history* and return it."""
    history.append({"role": "user", "content": user_text})
    history.append({"role": "assistant", "content": bot_text})
    return history


def respond(message, history, category):
    """Chat callback: reply to *message* with the best-matching quote.

    Args:
        message: Text entered by the user.
        history: Conversation so far, a list of ``{"role", "content"}``
            dicts that is extended in place; ``None`` is treated as an
            empty conversation.
        category: Currently selected category. Only its presence is checked
            here; it is not passed on to ``best_match_quote``.

    Returns:
        A ``("", history)`` tuple: an empty string for the message box and
        the conversation including the new user/assistant pair.
    """
    if history is None:
        history = []
    if not QUOTES:
        return "", _append_exchange(
            history, message, "No dataset loaded. Please upload a JSON file first."
        )
    if not category:
        return "", _append_exchange(history, message, "Please select a category.")

    quote = best_match_quote(message)
    if quote.startswith("No data about '"):
        # Nothing matched: show the notice once rather than repeating it as
        # summary, detail and trailer.
        return "", _append_exchange(history, message, quote)

    # The summary is the first sentence. partition() only adds the period
    # back when a ". " separator was actually consumed, so a one-sentence
    # quote is not given a doubled trailing period.
    head, sep, _rest = quote.partition(". ")
    summary = f"{head}." if sep else quote
    bot_text = f"Summary:\n{summary}\n\nWhat real people say:\n{quote}"
    return "", _append_exchange(history, message, bot_text)
def clear_chat():
    """Return ``None``, the value the Clear button writes to the chatbot."""
    return None
def upload_json(filepath):
    """Replace the active dataset with the JSON file at *filepath*.

    On success the module-level ``QUOTES`` is replaced and ``DATA_PATH`` is
    set to the file's base name. On failure both are left untouched.

    Args:
        filepath: Path of the uploaded file.

    Returns:
        A ``(status, category_update)`` tuple. ``status`` is a message for
        the status box. ``category_update`` carries the sorted category
        names (first one selected) after a successful load, and is an empty
        ``gr.update()`` after a failed one so the dropdown keeps matching
        the dataset that is still loaded.
    """
    global QUOTES, DATA_PATH
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        # UI boundary: report any read/parse failure in the status box
        # instead of letting the callback raise.
        return f"Error loading file: {e}", gr.update()
    if not isinstance(data, dict):
        return "Upload failed: JSON root must be an object.", gr.update()

    QUOTES = data
    DATA_PATH = os.path.basename(filepath)
    cats = sorted(QUOTES)
    status = f"Loaded {len(cats)} categories from {DATA_PATH}."
    return status, gr.update(choices=cats, value=(cats[0] if cats else None))
def download_current_json():
    """Write the active dataset to a temporary ``.json`` file.

    The file is created with ``delete=False``, so it remains on disk after
    this function returns; nothing here removes it.

    Returns:
        The path of the written file.
    """
    # Write through the temp-file handle itself so it is closed on exit,
    # rather than leaving it open and reopening the file by name.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", delete=False, encoding="utf-8"
    ) as tmp_file:
        json.dump(QUOTES, tmp_file, indent=2, ensure_ascii=False)
    return tmp_file.name
def download_conversation_csv(history):
    """Export the conversation to a temporary ``.csv`` file.

    The file has a ``role,message`` header followed by one row per message.
    It is created with ``delete=False``, so it remains on disk after this
    function returns.

    Args:
        history: Iterable of message dicts read via ``.get("role")`` and
            ``.get("content")``; ``None`` or an empty list produces a file
            containing only the header.

    Returns:
        The path of the written file.
    """
    # newline="" lets the csv module control line endings. Writing through
    # the temp-file handle ensures it is closed instead of leaked.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as tmp_file:
        writer = csv.writer(tmp_file)
        writer.writerow(["role", "message"])
        writer.writerows(
            [msg.get("role"), msg.get("content")] for msg in (history or [])
        )
    return tmp_file.name
# -----------------------------
# UI
# -----------------------------
# Build the Gradio interface and wire its events to the callbacks above.
# NOTE(review): indentation is missing in this copy of the file, so which
# components belong inside each gr.Row() cannot be determined from here;
# restore the nesting from the original before running.
with gr.Blocks() as demo:
gr.Markdown("## 🎓 College Life Chatbot — Category-Aware, Fuzzy Matching")
# Categories of the dataset loaded at startup; upload_json() later sends
# the dropdown a fresh list of choices.
initial_categories = sorted(list(QUOTES.keys()))
with gr.Row():
category = gr.Dropdown(
label="Category",
choices=initial_categories,
value=(initial_categories[0] if initial_categories else None)
)
# type="messages": respond() builds the history as {"role", "content"} dicts.
chatbot = gr.Chatbot(label="Conversation", height=360, type="messages")
msg = gr.Textbox(label="Your message", placeholder="Ask something like: 'Is food good in college?'", autofocus=True)
send = gr.Button("Send")
clear = gr.Button("Clear")
with gr.Row():
upload_btn = gr.File(label="Upload dataset (.json)", file_types=[".json"], type="filepath")
upload_status = gr.Textbox(label="Upload status", interactive=False)
# Download controls: each button has a gr.File output that receives the
# path returned by its callback.
with gr.Row():
download_json_btn = gr.Button("Download current dataset (.json)")
download_csv_btn = gr.Button("Export conversation to CSV")
download_json_file = gr.File(label="JSON download")
download_csv_file = gr.File(label="CSV download")
# Events
# The textbox submit event and the Send button run the same handler; its two
# return values go to the textbox (an empty string) and the chatbot.
msg.submit(respond, [msg, chatbot, category], [msg, chatbot])
send.click(respond, [msg, chatbot, category], [msg, chatbot])
# clear_chat() returns None, which is written to the chatbot.
clear.click(clear_chat, None, chatbot, queue=False)
# upload_json() returns the status text and an update for the dropdown.
upload_btn.upload(upload_json, upload_btn, [upload_status, category])
# Button click -> callback writes a temp file -> path shown in the File output.
download_json_btn.click(fn=download_current_json, inputs=None, outputs=download_json_file)
download_csv_btn.click(fn=download_conversation_csv, inputs=chatbot, outputs=download_csv_file)
# -----------------------------
# Startup log
# -----------------------------
# Banner with the local start time; the format spec is handed to strftime.
print(f"===== Application Startup at {datetime.datetime.now():%Y-%m-%d %H:%M:%S} =====")
# One line per category; an empty dataset simply prints nothing.
for cat, entries in QUOTES.items():
    print(f" - {cat}: {len(entries)} entries")

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)