Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import urllib.request | |
| import urllib.error | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import gradio as gr | |
| import pandas as pd | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import DistilBertForSequenceClassification, DistilBertTokenizer | |
# -----------------------------
# Model load
# -----------------------------
# Resolve paths relative to this file so the app works regardless of CWD.
ROOT = Path(__file__).parent
MODEL_DIR = ROOT / "models" / "SM3_binary_model"

# Prefer GPU when available; falls back to CPU (typical on free Spaces tier).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fine-tuned DistilBERT binary sentiment classifier and its tokenizer,
# loaded once at import time and shared by every request.
model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_DIR)
model.to(device)
model.eval()  # inference only — disables dropout
# Class-id -> name mapping from the model config.
id2label = model.config.id2label  # {0:'negative', 1:'positive'}
def predict_one(text: str):
    """Classify a single review as positive/negative.

    Args:
        text: Raw review text.

    Returns:
        (label, confidence): the class name from ``id2label`` and its
        softmax probability as a plain float.
    """
    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128  # reviews are short; 128 tokens keeps inference fast
    ).to(device)
    with torch.no_grad():
        out = model(**enc)
    # Argmax directly on the probability tensor — the original converted to a
    # Python list and rebuilt a tensor just to take argmax.
    probs = F.softmax(out.logits, dim=-1).squeeze(0)
    pred_id = int(torch.argmax(probs).item())
    label = id2label[pred_id]
    conf = float(probs[pred_id].item())
    return label, conf
| # ----------------------------- | |
| # Cheap feature extraction (keywords) | |
| # ----------------------------- | |
# Feature themes and the keyword triggers that count as a "mention" of each.
FEATURE_KEYWORDS = {
    "Print Speed": ["speed", "fast", "quick", "slow"],
    "Print Quality": ["quality", "sharp", "clear", "blur", "smudge", "colour", "color"],
    "Reliability": ["reliable", "consistent", "durable", "broke", "broken", "stopped", "fault", "issue", "jam", "jams"],
    "Ease of Use": ["easy", "setup", "install", "installation", "simple", "user-friendly"],
    "Connectivity": ["wifi", "wireless", "bluetooth", "connection", "disconnect", "network"],
    "Noise": ["noisy", "loud", "quiet"],
    "Value for Money": ["value", "worth", "price", "expensive", "cheap", "cost"],
    "Toner/Ink Cost": ["toner", "ink", "cartridge", "refill"],
}


def extract_features(text: str):
    """Return the feature names whose keywords occur (substring match,
    case-insensitive) in *text*; order follows FEATURE_KEYWORDS."""
    lowered = (text or "").lower()
    return [
        feature
        for feature, keywords in FEATURE_KEYWORDS.items()
        if any(kw in lowered for kw in keywords)
    ]
def build_feature_tables(df_out: pd.DataFrame):
    """Aggregate per-review predictions into keyword-based feature tables.

    Args:
        df_out: DataFrame with columns ``review``, ``sentiment``, ``confidence``.

    Returns:
        (love_df, concern_df, feat_df):
            love_df    — top-5 themes by positive share (theme, positive_%, mentions)
            concern_df — top-5 themes by negative share (theme, negative_%, mentions)
            feat_df    — full per-feature table sorted by mentions, then rating.
    """
    stats = defaultdict(lambda: {"mentions": 0, "pos": 0, "neg": 0, "conf_sum": 0.0})
    for _, row in df_out.iterrows():
        review = str(row["review"])
        sent = str(row["sentiment"]).lower()
        conf = float(row["confidence"])
        feats = extract_features(review)
        for f in feats:
            stats[f]["mentions"] += 1
            stats[f]["conf_sum"] += conf
            # Anything not literally "positive" counts as negative.
            if sent == "positive":
                stats[f]["pos"] += 1
            else:
                stats[f]["neg"] += 1
    rows = []
    for f, s in stats.items():
        m = s["mentions"]
        if m == 0:
            continue  # defensive: a key only exists once mentions >= 1
        pos_pct = s["pos"] / m * 100
        neg_pct = s["neg"] / m * 100
        avg_conf = s["conf_sum"] / m
        # Map positive share linearly onto a 1-5 star scale.
        rating = 1 + 4 * (pos_pct / 100.0)
        rows.append({
            "feature": f,
            "mentions": m,
            "positive_%": round(pos_pct, 1),
            "negative_%": round(neg_pct, 1),
            "avg_conf": round(avg_conf, 3),
            "rating_1to5": round(rating, 2),
        })
    columns = ["feature", "mentions", "positive_%", "negative_%", "avg_conf", "rating_1to5"]
    # BUG FIX: build with explicit columns. The original did
    # pd.DataFrame(rows).sort_values(...) BEFORE its empty check, so when no
    # review matched any keyword (rows == []) the frame had no "mentions"
    # column and sort_values raised KeyError.
    feat_df = pd.DataFrame(rows, columns=columns)
    feat_df = feat_df.sort_values(by=["mentions", "rating_1to5"], ascending=[False, False])
    love_df = feat_df.sort_values(by=["positive_%", "mentions"], ascending=[False, False]).head(5).copy()
    love_df = love_df[["feature", "positive_%", "mentions"]]
    love_df.columns = ["theme", "positive_%", "mentions"]
    concern_df = feat_df.sort_values(by=["negative_%", "mentions"], ascending=[False, False]).head(5).copy()
    concern_df = concern_df[["feature", "negative_%", "mentions"]]
    concern_df.columns = ["theme", "negative_%", "mentions"]
    return love_df, concern_df, feat_df
| # ----------------------------- | |
| # Gemini REST (reliable in Spaces) | |
| # ----------------------------- | |
def gemini_ready():
    """True when a non-blank GEMINI_API_KEY secret is configured."""
    api_key = os.environ.get("GEMINI_API_KEY", "")
    return api_key.strip() != ""
def gemini_generate_insights(history_reviews: list, df_out: pd.DataFrame) -> str:
    """Summarize accumulated review sentiment via the Gemini REST API.

    Args:
        history_reviews: All submitted reviews. NOTE(review): currently unused —
            the prompt is built entirely from ``df_out``.
        df_out: Per-review predictions with columns ``review``, ``sentiment``,
            ``confidence``.

    Returns:
        The generated insight text, or a human-readable error string when the
        API key is missing or the request fails (this function never raises).
    """
    key = os.environ.get("GEMINI_API_KEY", "").strip()
    if not key:
        return "Gemini not configured: missing GEMINI_API_KEY secret."
    pos = (df_out["sentiment"].str.lower() == "positive").sum()
    neg = len(df_out) - pos
    pos_examples = df_out[df_out["sentiment"].str.lower() == "positive"]["review"].head(5).tolist()
    neg_examples = df_out[df_out["sentiment"].str.lower() == "negative"]["review"].head(5).tolist()
    prompt = f"""
Overall sentiment: Positive={pos}, Negative={neg}, Total={len(df_out)}
Positive examples:
{chr(10).join([f"- {x}" for x in pos_examples])}
Negative examples:
{chr(10).join([f"- {x}" for x in neg_examples])}
Write: summary, loves, concerns, improvements (concise).
""".strip()
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
    payload = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
    # SECURITY FIX: send the key in the x-goog-api-key header (supported by
    # the Gemini API) instead of the URL query string, so it cannot leak into
    # access logs or proxies that record request URLs.
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json", "x-goog-api-key": key},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=25) as resp:
            body = resp.read().decode("utf-8")
        data = json.loads(body)
        # Defensive extraction: every level defaults to empty so a malformed
        # response degrades to the "no text" message instead of raising.
        text = (
            data.get("candidates", [{}])[0]
            .get("content", {})
            .get("parts", [{}])[0]
            .get("text", "")
        )
        return text.strip() or f"Gemini returned no text. Raw: {body[:200]}"
    except urllib.error.HTTPError as e:
        detail = e.read().decode("utf-8") if hasattr(e, "read") else str(e)
        return f"Gemini HTTPError {e.code}: {detail[:300]}"
    except Exception as e:
        return f"Gemini failed: {type(e).__name__}: {e}"
| # ----------------------------- | |
| # Stateful app logic (history) | |
| # ----------------------------- | |
def submit_and_accumulate(new_text: str, history: list):
    """Append newline-separated reviews to the history and rebuild the dashboard.

    Returns the 9-tuple wired to the Submit button: (state, history text,
    overall summary, per-review table, breakdown markdown, love table,
    concern table, feature table, cleared AI box).
    """
    history = history or []
    incoming = [line.strip() for line in (new_text or "").splitlines() if line.strip()]
    history.extend(incoming)
    if not history:
        empty = pd.DataFrame(columns=["review", "sentiment", "confidence"])
        return history, "", "No input.", empty, "", pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), ""
    # Re-classify the entire accumulated history on every submit.
    records = []
    for review in history:
        label, conf = predict_one(review)
        records.append({"review": review, "sentiment": label, "confidence": round(conf, 3)})
    df_out = pd.DataFrame(records)
    pos = (df_out["sentiment"].str.lower() == "positive").sum()
    neg = len(df_out) - pos
    total = len(df_out)
    overall = f"Positive: {pos} | Negative: {neg} | Total: {total}"
    breakdown = f"- Positive: {pos} ({pos/total*100:.1f}%)\n- Negative: {neg} ({neg/total*100:.1f}%)"
    love_df, concern_df, feat_df = build_feature_tables(df_out)
    return history, "\n".join(history), overall, df_out, breakdown, love_df, concern_df, feat_df, ""
def run_gemini_from_history(history: list):
    """Re-score every review in *history*, then request Gemini insights."""
    if not history:
        return ""
    scored = []
    for review in history:
        sentiment, confidence = predict_one(review)
        scored.append({"review": review, "sentiment": sentiment, "confidence": round(confidence, 3)})
    return gemini_generate_insights(history, pd.DataFrame(scored))
def clear_all():
    """Reset every dashboard output and wipe the stored review history."""
    results = pd.DataFrame(columns=["review", "sentiment", "confidence"])
    return (
        [],               # history state
        "",               # history textbox
        "No input.",      # overall summary
        results,          # per-review table (empty, headers kept)
        "",               # breakdown markdown
        pd.DataFrame(),   # love table
        pd.DataFrame(),   # concern table
        pd.DataFrame(),   # feature table
        "",               # AI insights box
    )
# -----------------------------
# UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# SM3: Customer Review Sentiment Analysis")
    gr.Markdown(
        "Submit reviews (one per line). Each submit **adds to history** and updates the dashboard. "
        "Optional: generate **AI Insights (Gemini)** from accumulated reviews."
    )
    # Per-session accumulated list of submitted review strings.
    history_state = gr.State([])
    inp = gr.Textbox(label="New Reviews (one per line)", lines=6, placeholder="Paste new reviews here...")
    with gr.Row():
        submit = gr.Button("Submit")
        genai_btn = gr.Button("Generate AI Insights (Gemini)")
        clear = gr.Button("Clear All")
    # Dashboard outputs — all refreshed together by the Submit handler.
    history_box = gr.Textbox(label="History (all submitted reviews)", lines=6)
    overall = gr.Textbox(label="Overall Sentiment")
    table = gr.Dataframe(label="Per-review Results", wrap=True)
    breakdown = gr.Markdown()
    gr.Markdown("## What Users Love")
    love_table = gr.Dataframe(label="Top Positive Themes", wrap=True)
    gr.Markdown("## Common Concerns")
    concern_table = gr.Dataframe(label="Top Negative Themes", wrap=True)
    gr.Markdown("## Feature Ratings (keyword-based)")
    feat_table = gr.Dataframe(label="Feature Ratings", wrap=True)
    ai_box = gr.Markdown(label="AI Insights (Gemini)")
    # Submit: classify the full history and update every panel; the final ""
    # output clears the AI box so stale insights never outlive the data.
    submit.click(
        submit_and_accumulate,
        inputs=[inp, history_state],
        outputs=[history_state, history_box, overall, table, breakdown, love_table, concern_table, feat_table, ai_box]
    )
    # On-demand Gemini insight generation from the accumulated history.
    genai_btn.click(
        run_gemini_from_history,
        inputs=[history_state],
        outputs=[ai_box]
    )
    # Reset everything, including the stored history state.
    clear.click(
        clear_all,
        outputs=[history_state, history_box, overall, table, breakdown, love_table, concern_table, feat_table, ai_box]
    )
demo.launch()