# SM3 review-sentiment Space — switched to Gemini Pro (commit 89ef36b).
import os
import json
import urllib.request
import urllib.error
from pathlib import Path
from collections import defaultdict
import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
# -----------------------------
# Model load
# -----------------------------
# Resolve paths relative to this file so the app works regardless of CWD.
ROOT = Path(__file__).parent
MODEL_DIR = ROOT / "models" / "SM3_binary_model"
# Prefer GPU when available; inference falls back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fine-tuned DistilBERT binary sentiment classifier bundled with the Space.
model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_DIR)
model.to(device)
model.eval()  # inference only: disables dropout
# Class-index -> label-name mapping read from the checkpoint config.
# Presumably {0: 'negative', 1: 'positive'} — confirm against MODEL_DIR's config.json.
id2label = model.config.id2label # {0:'negative', 1:'positive'}
def predict_one(text: str):
    """Classify a single review as positive/negative.

    Args:
        text: Raw review text (may be empty).

    Returns:
        (label, confidence): the predicted class name from ``id2label`` and
        its softmax probability as a float in [0, 1].
    """
    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,  # reviews are short; truncate long outliers
    ).to(device)
    with torch.no_grad():
        out = model(**enc)
    # Batch of one: squeeze to a 1-D probability vector.
    probs = F.softmax(out.logits, dim=-1).squeeze(0)
    # Argmax directly on the tensor — the original converted to a Python
    # list and back to a tensor just to take the argmax.
    pred_id = int(torch.argmax(probs).item())
    label = id2label[pred_id]
    conf = float(probs[pred_id].item())
    return label, conf
# -----------------------------
# Cheap feature extraction (keywords)
# -----------------------------
# Product feature -> trigger keywords. Matching is lowercase substring search.
FEATURE_KEYWORDS = {
    "Print Speed": ["speed", "fast", "quick", "slow"],
    "Print Quality": ["quality", "sharp", "clear", "blur", "smudge", "colour", "color"],
    "Reliability": ["reliable", "consistent", "durable", "broke", "broken", "stopped", "fault", "issue", "jam", "jams"],
    "Ease of Use": ["easy", "setup", "install", "installation", "simple", "user-friendly"],
    "Connectivity": ["wifi", "wireless", "bluetooth", "connection", "disconnect", "network"],
    "Noise": ["noisy", "loud", "quiet"],
    "Value for Money": ["value", "worth", "price", "expensive", "cheap", "cost"],
    "Toner/Ink Cost": ["toner", "ink", "cartridge", "refill"],
}


def extract_features(text: str):
    """Return the feature names whose keywords appear in *text*.

    Matching is case-insensitive substring search; a feature is listed at
    most once, in FEATURE_KEYWORDS declaration order. ``None``/empty input
    yields an empty list.
    """
    lowered = (text or "").lower()
    return [
        feature
        for feature, keywords in FEATURE_KEYWORDS.items()
        if any(kw in lowered for kw in keywords)
    ]
def build_feature_tables(df_out: pd.DataFrame):
    """Aggregate per-review predictions into keyword-based feature tables.

    Args:
        df_out: DataFrame with "review", "sentiment", "confidence" columns.

    Returns:
        (love_df, concern_df, feat_df):
            love_df    – top 5 themes by positive share: theme, positive_%, mentions
            concern_df – top 5 themes by negative share: theme, negative_%, mentions
            feat_df    – full per-feature stats, sorted by mentions then rating.
    """
    stats = defaultdict(lambda: {"mentions": 0, "pos": 0, "neg": 0, "conf_sum": 0.0})
    for _, row in df_out.iterrows():
        review = str(row["review"])
        sent = str(row["sentiment"]).lower()
        conf = float(row["confidence"])
        for f in extract_features(review):
            stats[f]["mentions"] += 1
            stats[f]["conf_sum"] += conf
            if sent == "positive":
                stats[f]["pos"] += 1
            else:
                # anything not labelled "positive" counts as negative
                stats[f]["neg"] += 1

    rows = []
    for f, s in stats.items():
        m = s["mentions"]  # always >= 1 for keys present in stats
        pos_pct = s["pos"] / m * 100
        neg_pct = s["neg"] / m * 100
        avg_conf = s["conf_sum"] / m
        # Linear map of positive share onto a 1..5 star-style rating.
        rating = 1 + 4 * (pos_pct / 100.0)
        rows.append({
            "feature": f,
            "mentions": m,
            "positive_%": round(pos_pct, 1),
            "negative_%": round(neg_pct, 1),
            "avg_conf": round(avg_conf, 3),
            "rating_1to5": round(rating, 2),
        })

    # Pass columns= explicitly: with rows == [] the original built a frame
    # with NO columns and the sort below raised KeyError before the empty
    # check could run. Named (empty) columns make every sort safe.
    columns = ["feature", "mentions", "positive_%", "negative_%", "avg_conf", "rating_1to5"]
    feat_df = pd.DataFrame(rows, columns=columns).sort_values(
        by=["mentions", "rating_1to5"], ascending=[False, False]
    )

    love_df = feat_df.sort_values(by=["positive_%", "mentions"], ascending=[False, False]).head(5).copy()
    love_df = love_df[["feature", "positive_%", "mentions"]]
    love_df.columns = ["theme", "positive_%", "mentions"]

    concern_df = feat_df.sort_values(by=["negative_%", "mentions"], ascending=[False, False]).head(5).copy()
    concern_df = concern_df[["feature", "negative_%", "mentions"]]
    concern_df.columns = ["theme", "negative_%", "mentions"]
    return love_df, concern_df, feat_df
# -----------------------------
# Gemini REST (reliable in Spaces)
# -----------------------------
def gemini_ready():
    """Return True when a non-blank GEMINI_API_KEY secret is configured."""
    api_key = os.environ.get("GEMINI_API_KEY", "")
    return api_key.strip() != ""
def gemini_generate_insights(history_reviews: list, df_out: pd.DataFrame) -> str:
    """Summarize accumulated review sentiment with the Gemini REST API.

    Args:
        history_reviews: All submitted reviews (kept for interface
            compatibility; the prompt is built from ``df_out``).
        df_out: Per-review predictions with "review", "sentiment",
            "confidence" columns.

    Returns:
        Generated insight text, or a human-readable error string — this
        function never raises, so the UI always has something to display.
    """
    key = os.environ.get("GEMINI_API_KEY", "").strip()
    if not key:
        return "Gemini not configured: missing GEMINI_API_KEY secret."
    sentiments = df_out["sentiment"].str.lower()
    pos = (sentiments == "positive").sum()
    neg = len(df_out) - pos
    pos_examples = df_out[sentiments == "positive"]["review"].head(5).tolist()
    neg_examples = df_out[sentiments == "negative"]["review"].head(5).tolist()
    prompt = f"""
Overall sentiment: Positive={pos}, Negative={neg}, Total={len(df_out)}
Positive examples:
{chr(10).join([f"- {x}" for x in pos_examples])}
Negative examples:
{chr(10).join([f"- {x}" for x in neg_examples])}
Write: summary, loves, concerns, improvements (concise).
""".strip()
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
    payload = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        # Send the key as a header rather than a ?key= query parameter so it
        # is not exposed in server/proxy access logs.
        headers={"Content-Type": "application/json", "x-goog-api-key": key},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=25) as resp:
            body = resp.read().decode("utf-8")
        data = json.loads(body)
        # `or [{}]` also covers explicit empty lists (e.g. safety-blocked
        # responses), which `.get(..., [{}])[0]` would turn into IndexError.
        candidates = data.get("candidates") or [{}]
        parts = candidates[0].get("content", {}).get("parts") or [{}]
        text = parts[0].get("text", "")
        return text.strip() or f"Gemini returned no text. Raw: {body[:200]}"
    except urllib.error.HTTPError as e:
        detail = e.read().decode("utf-8") if hasattr(e, "read") else str(e)
        return f"Gemini HTTPError {e.code}: {detail[:300]}"
    except Exception as e:
        # Network/timeout/JSON failures surface as text instead of crashing the UI.
        return f"Gemini failed: {type(e).__name__}: {e}"
# -----------------------------
# Stateful app logic (history)
# -----------------------------
def submit_and_accumulate(new_text: str, history: list):
    """Append newline-separated reviews to history and rebuild the dashboard.

    Returns the 9-tuple wired to the Submit button: (history state,
    history text, overall summary, per-review DataFrame, breakdown markdown,
    love table, concern table, feature table, cleared AI-insights box).
    """
    history = history or []
    fresh = [line.strip() for line in (new_text or "").splitlines() if line.strip()]
    history.extend(fresh)

    if not history:
        empty = pd.DataFrame(columns=["review", "sentiment", "confidence"])
        return history, "", "No input.", empty, "", pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), ""

    # Re-classify the entire accumulated history on every submit.
    records = []
    for review in history:
        label, conf = predict_one(review)
        records.append({"review": review, "sentiment": label, "confidence": round(conf, 3)})
    df_out = pd.DataFrame(records)

    total = len(df_out)
    pos = (df_out["sentiment"].str.lower() == "positive").sum()
    neg = total - pos
    overall = f"Positive: {pos} | Negative: {neg} | Total: {total}"
    breakdown = f"- Positive: {pos} ({pos/total*100:.1f}%)\n- Negative: {neg} ({neg/total*100:.1f}%)"
    love_df, concern_df, feat_df = build_feature_tables(df_out)
    return history, "\n".join(history), overall, df_out, breakdown, love_df, concern_df, feat_df, ""
def run_gemini_from_history(history: list):
    """Re-score all accumulated reviews and request Gemini insights.

    Returns "" when there is no history; otherwise the Gemini response
    (or an error string from gemini_generate_insights).
    """
    if not history:
        return ""
    scored = []
    for review in history:
        label, conf = predict_one(review)
        scored.append({"review": review, "sentiment": label, "confidence": round(conf, 3)})
    return gemini_generate_insights(history, pd.DataFrame(scored))
def clear_all():
    """Reset the history state and every dashboard output to its initial value."""
    results = pd.DataFrame(columns=["review", "sentiment", "confidence"])
    blank_tables = (pd.DataFrame(), pd.DataFrame(), pd.DataFrame())
    return ([], "", "No input.", results, "", *blank_tables, "")
# -----------------------------
# UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# SM3: Customer Review Sentiment Analysis")
    gr.Markdown(
        "Submit reviews (one per line). Each submit **adds to history** and updates the dashboard. "
        "Optional: generate **AI Insights (Gemini)** from accumulated reviews."
    )
    # Per-session accumulated reviews; persists across button clicks.
    history_state = gr.State([])
    inp = gr.Textbox(label="New Reviews (one per line)", lines=6, placeholder="Paste new reviews here...")
    with gr.Row():
        submit = gr.Button("Submit")
        genai_btn = gr.Button("Generate AI Insights (Gemini)")
        clear = gr.Button("Clear All")
    history_box = gr.Textbox(label="History (all submitted reviews)", lines=6)
    overall = gr.Textbox(label="Overall Sentiment")
    table = gr.Dataframe(label="Per-review Results", wrap=True)
    breakdown = gr.Markdown()
    gr.Markdown("## What Users Love")
    love_table = gr.Dataframe(label="Top Positive Themes", wrap=True)
    gr.Markdown("## Common Concerns")
    concern_table = gr.Dataframe(label="Top Negative Themes", wrap=True)
    gr.Markdown("## Feature Ratings (keyword-based)")
    feat_table = gr.Dataframe(label="Feature Ratings", wrap=True)
    ai_box = gr.Markdown(label="AI Insights (Gemini)")
    # Submit: classify the full history and refresh every dashboard widget;
    # the final "" output clears the AI-insights box.
    submit.click(
        submit_and_accumulate,
        inputs=[inp, history_state],
        outputs=[history_state, history_box, overall, table, breakdown, love_table, concern_table, feat_table, ai_box]
    )
    # Gemini insights run on demand only (network call).
    genai_btn.click(
        run_gemini_from_history,
        inputs=[history_state],
        outputs=[ai_box]
    )
    # Clear: reset the state and every output to its initial value.
    clear.click(
        clear_all,
        outputs=[history_state, history_box, overall, table, breakdown, love_table, concern_table, feat_table, ai_box]
    )
demo.launch()