| from typing import List, Dict, Any, Optional |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, Field |
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| import os |
|
|
|
|
|
|
# Directory the tokenizer and model weights are loaded from; the MODEL_DIR
# environment variable overrides the local default.
MODEL_DIR = os.environ.get("MODEL_DIR", "./sentiment_model")

# Predicted class index -> label name.
LABEL_MAP = dict(enumerate(("positive", "negative", "neutral", "irrelevant")))

# Label name -> class index (inverse of LABEL_MAP).
REVERSE_LABEL_MAP = {name: idx for idx, name in LABEL_MAP.items()}

# max_length handed to the tokenizer; truncation is enabled there, so longer
# inputs are cut to this many tokens.
MAX_LEN = 128
|
|
| |
# Canned sample statements, keyed by the lower-case influencer identifier
# that the /analyze endpoint looks up.
INFLUENCER_STATEMENTS = dict(
    elon_musk=[
        "Tesla is pushing the boundaries of sustainable energy.",
        "We might need to slow down hiring this quarter.",
        "Starship test looked promising but we've got lots to fix.",
        "The new features are amazing—huge improvements coming soon!",
    ],
    vitalik_buterin=[
        "Layer-2 scaling is crucial for mainstream adoption.",
        "Fees remain a challenge; we must keep iterating.",
        "Great progress from the community this month!",
        "Beware of scams claiming guaranteed returns.",
    ],
    kanye_west=[
        "My art defines the culture.",
        "Haters will always talk, but the vision is bigger.",
        "New project dropping—it's a masterpiece.",
        "Some people just don't understand the mission yet.",
    ],
)
|
|
|
|
# ASGI application object; the route decorators below register on it.
app = FastAPI(
    title="Influencer Sentiment API",
    version="1.0.0",
)
|
|
|
|
# Request body for POST /analyze: names which stored influencer to analyze.
class AnalyzeInfluencerIn(BaseModel):
    # Required; the endpoint strips and lower-cases it before lookup.
    influencer: str = Field(
        default=...,
        description="Influencer key, e.g. 'elon_musk'",
    )
|
|
|
|
# Request body for POST /analyze_text: one free-form statement.
class AnalyzeTextIn(BaseModel):
    # Required; passed to the classifier as a single-item batch.
    text: str = Field(
        default=...,
        description="A single statement for analysis",
    )
|
|
|
|
# Request body for POST /batch_analyze: several statements plus an optional
# influencer name that is only used when wording the summary.
class BatchAnalyzeIn(BaseModel):
    # Required; the endpoint rejects an empty list with HTTP 400.
    statements: List[str] = Field(
        default=...,
        description="List of statements to analyze",
    )
    # Optional; defaults to None.
    influencer: Optional[str] = Field(
        default=None,
        description="Optional influencer name for context",
    )
|
|
|
|
|
|
# Load the tokenizer and classifier once, at import time, so every request
# handler shares the same in-memory objects. Any failure here aborts startup.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    # Use the GPU when torch reports one is available; otherwise stay on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Put the module in evaluation mode; this service only runs inference.
    model.eval()
except Exception as e:
    # Chain explicitly so the traceback marks the loader error as the direct
    # cause of the startup failure rather than an incidental second error.
    raise RuntimeError(f"Failed to load model from {MODEL_DIR}: {e}") from e
|
|
|
|
|
|
@torch.inference_mode()
def predict(texts: List[str]) -> List[Dict[str, Any]]:
    """Classify one or more statements with the module-level model.

    Args:
        texts: Statements to classify. A bare string is accepted and treated
            as a single-item batch.

    Returns:
        One dict per input, in input order, with keys ``label`` (name of the
        highest-probability class), ``score`` (that class's softmax
        probability) and ``probs`` (label name -> probability for every
        class). An empty input yields an empty list.

    Raises:
        KeyError: If the model emits a class index that LABEL_MAP lacks.
    """
    if isinstance(texts, str):
        texts = [texts]
    # Nothing to classify: return early instead of handing the tokenizer and
    # model an empty batch.
    if not texts:
        return []

    enc = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    # Inputs must live on the same device as the model.
    enc = {k: v.to(device) for k, v in enc.items()}

    logits = model(**enc).logits
    probs = F.softmax(logits, dim=-1)
    confs, preds = torch.max(probs, dim=-1)

    # Convert each tensor to plain Python values once, then walk the rows in
    # lockstep rather than indexing the tensors element by element.
    return [
        {
            "label": LABEL_MAP[label_id],
            "score": float(score),
            "probs": {LABEL_MAP[j]: float(p) for j, p in enumerate(prob_row)},
        }
        for label_id, score, prob_row in zip(
            preds.tolist(), confs.tolist(), probs.tolist()
        )
    ]
|
|
|
|
|
|
def generate_summary(results: List[Dict[str, Any]], influencer: Optional[str] = None) -> str:
    """Condense a batch of predictions into one human-readable paragraph.

    Args:
        results: Prediction dicts; each must carry a ``label`` that is one of
            LABEL_MAP's values and a numeric ``score``.
        influencer: Name used as the subject of the sentence. When missing or
            empty, "This influencer" is used instead.

    Returns:
        "No statements analyzed." for an empty list. Otherwise a paragraph
        naming the most frequent label's tone, its share of the total, the
        mean score, and a label-specific recommendation. Ties go to the label
        that appears first in LABEL_MAP.
    """
    n_results = len(results)
    if not n_results:
        return "No statements analyzed."

    # Every known label starts at zero so absent labels still take part in
    # the tie-breaking order.
    tally: Dict[str, int] = dict.fromkeys(LABEL_MAP.values(), 0)
    score_sum = 0.0
    for item in results:
        tally[item["label"]] += 1
        score_sum += item["score"]
    mean_score = score_sum / n_results

    # max() keeps the first maximal key it meets, i.e. LABEL_MAP order.
    top_label = max(tally, key=tally.get)

    tone_by_label = {
        "positive": "mostly positive / optimistic",
        "negative": "mostly negative / critical",
        "neutral": "mixed or neutral",
        "irrelevant": "largely off-topic or not sentiment-bearing",
    }
    tone = tone_by_label.get(top_label, top_label)
    subject = influencer if influencer else "This influencer"

    headline = (
        f"{subject} shows {tone} tone "
        f"({tally[top_label]}/{n_results}). "
        f"Confidence ~{mean_score:.2f}. "
    )

    # Recommendation per dominant label; anything else gets the fallback.
    advice_by_label = {
        "positive": "Likely to drive favorable audience reactions and engagement.",
        "negative": "Expect pushback; consider addressing concerns proactively.",
        "neutral": "Messaging is balanced; highlight clearer value props to move sentiment.",
    }
    fallback_advice = "Many statements lack sentiment; focus on clearer, value-forward messages."
    return headline + advice_by_label.get(top_label, fallback_advice)
|
|
|
|
| |
|
|
# Liveness probe: always answers with a fixed {"status": "ok"} payload and
# does not touch the model.
@app.get("/health")
def health() -> Dict[str, str]:
    return dict(status="ok")
|
|
|
|
# Classify the stored statements of a known influencer.
# The lookup key is the request value stripped and lower-cased; an unknown
# key produces HTTP 404 listing the available keys. The response echoes the
# influencer value exactly as supplied, alongside per-statement predictions,
# a text summary, and a per-label count.
@app.post("/analyze")
def analyze_influencer(body: AnalyzeInfluencerIn) -> Dict[str, Any]:
    statements = INFLUENCER_STATEMENTS.get(body.influencer.strip().lower())
    if statements is None:
        raise HTTPException(
            status_code=404,
            detail=f"Unknown influencer '{body.influencer}'. Available: {list(INFLUENCER_STATEMENTS)}",
        )

    # Pair every statement with its prediction, keeping the response keys in
    # the order text, label, score, probs.
    results = [
        {"text": statement, **{field: outcome[field] for field in ("label", "score", "probs")}}
        for statement, outcome in zip(statements, predict(statements))
    ]

    summary = generate_summary(results, influencer=body.influencer)
    predicted_labels = [row["label"] for row in results]
    distribution = {
        lbl: predicted_labels.count(lbl) for lbl in LABEL_MAP.values()
    }

    return {
        "influencer": body.influencer,
        "results": results,
        "summary": summary,
        "distribution": distribution,
    }
|
|
|
|
# Classify a single statement and return its text together with the
# predicted label, that label's score, and the per-label probabilities.
@app.post("/analyze_text")
def analyze_text(body: AnalyzeTextIn) -> Dict[str, Any]:
    statement = body.text
    # predict() works on batches, so wrap the text and take the only row.
    outcome = predict([statement])[0]
    response: Dict[str, Any] = {"text": statement}
    for field in ("label", "score", "probs"):
        response[field] = outcome[field]
    return response
|
|
|
|
# Classify a caller-supplied list of statements.
# An empty list is rejected with HTTP 400. The response carries the number of
# results, per-statement predictions, a text summary (worded with the
# optional influencer name), and a per-label count.
@app.post("/batch_analyze")
def batch_analyze(body: BatchAnalyzeIn) -> Dict[str, Any]:
    statements = body.statements
    if not statements:
        raise HTTPException(status_code=400, detail="No statements provided.")

    # Pair every statement with its prediction, keeping the response keys in
    # the order text, label, score, probs.
    results = [
        {"text": statement, **{field: outcome[field] for field in ("label", "score", "probs")}}
        for statement, outcome in zip(statements, predict(statements))
    ]

    summary = generate_summary(results, influencer=body.influencer)
    predicted_labels = [row["label"] for row in results]
    distribution = {
        lbl: predicted_labels.count(lbl) for lbl in LABEL_MAP.values()
    }

    return {
        "count": len(results),
        "results": results,
        "summary": summary,
        "distribution": distribution,
    }
|
|