Spaces:
Sleeping
Sleeping
| import re | |
| import gradio as gr | |
| from transformers import pipeline | |
# Hugging Face Hub id of the fine-tuned FinBERT checkpoint; used for both the
# model weights and the tokenizer.
MODEL_ID = "JanhaviS14/finance-sentiment-mini-finbert"

# Shared text-classification pipeline used by every prediction path in this file.
# - top_k=3 is fixed at construction time; callers index the result of
#   clf("some text") with [0] to obtain the list of {"label", "score"} dicts.
#   NOTE(review): top_k=3 presumes the model has exactly three labels
#   (positive / negative / neutral) — confirm against the model config.
# - truncation=True cuts over-long inputs to the tokenizer's maximum length
#   instead of letting the model raise on sequences longer than it supports.
#   NOTE(review): relies on the tokenizer config defining model_max_length.
clf = pipeline(
    "text-classification",
    model=MODEL_ID,
    tokenizer=MODEL_ID,
    top_k=3,
    truncation=True,
)
# -----------------------------
# Negation guardrail
# -----------------------------
# Negators. They are anchored with a leading \b where they are used so that
# they only match as whole words.
NEG = r"(?:not|no|never|didn't|did not|isn't|is not|wasn't|was not|can't|cannot|won't|will not)"
# Growth / profitability cues. These are intentionally matched as prefixes so
# inflections such as "growth", "increased" or "improving" are also caught.
POS_CUES = r"(?:grow|increase|improve|expand|rise|surge|beat|strong|record|higher|profit|profits|pat|earnings|revenue|sales|margin|margins)"

# Compiled once at import time. The leading \b stops negators from matching
# inside longer words (e.g. "casino growth" must not read as "no growth").
_NEGATED_POSITIVE_PATTERNS = tuple(
    re.compile(p)
    for p in (
        rf"\b{NEG}\s+(?:much\s+)?{POS_CUES}",  # "did not grow", "not much increase"
        rf"\bnot\s+much\s+(?:of\s+an?\s+)?{POS_CUES}",  # "not much of an increase"
        rf"\b{NEG}\s+(?:any\s+)?(?:significant\s+)?{POS_CUES}",  # "no significant growth"
    )
)


def has_negated_positive_cue(text: str) -> bool:
    """Return True if *text* contains a negator directly before a positive cue.

    Matching is case-insensitive (the text is lower-cased first). Typographic
    right single quotes (U+2019) are normalised to ASCII apostrophes, so
    "didn’t grow" is detected just like "didn't grow".

    Args:
        text: Sentence or headline to inspect.

    Returns:
        True when any negated-positive pattern matches, otherwise False.
    """
    normalized = text.lower().replace("\u2019", "'")
    return any(p.search(normalized) for p in _NEGATED_POSITIVE_PATTERNS)
def apply_negation_guardrail(level: str, dist: dict, text: str) -> str:
    """Demote a positive assessment when *text* negates a growth/profit cue.

    Both "positive" and "moderately positive" are affected.

    Args:
        level: 5-level assessment produced upstream.
        dist: Label -> probability mapping. Only the "negative" entry is read;
            a missing entry counts as 0.0.
        text: Original input, checked with has_negated_positive_cue().

    Returns:
        - "moderately negative" if *level* is positive, the text contains a
          negated positive cue, and P(negative) >= 0.25.
        - "neutral" if the first two conditions hold but P(negative) < 0.25.
        - Otherwise *level*, unchanged.
    """
    negated = has_negated_positive_cue(text)
    if not (negated and level in {"positive", "moderately positive"}):
        return level
    # A noticeable amount of negative probability pushes the demotion past
    # neutral.
    negative_prob = float(dist.get("negative", 0.0))
    return "moderately negative" if negative_prob >= 0.25 else "neutral"
# -----------------------------
# Clause-level aggregation (robust for mixed statements)
# -----------------------------
# Clause delimiters: contrastive conjunctions (whole words, any case) plus the
# characters ";", ":" and newline.
CONTRAST_SPLIT = re.compile(
    r"\b(?:but|while|however|although|though|yet|despite|whereas)\b|[;:\n]",
    re.IGNORECASE
)
# Intensity words that scale a clause's weight (<1 dampens, >1 amplifies).
# Keys are matched as whole words, so adjective and adverb forms are listed
# separately.
INTENSITY = {
    "modest": 0.7, "modestly": 0.7, "slight": 0.7, "slightly": 0.7,
    "marginal": 0.7, "marginally": 0.7,
    "strong": 1.2, "strongly": 1.2,
    "significant": 1.3, "significantly": 1.3,
    "sharp": 1.4, "sharply": 1.4,
    "record": 1.4,
    "material": 1.3, "materially": 1.3,
}
# Profitability / margins typically dominate revenue in finance sentiment.
# The list is checked in order; the first matching pattern sets the multiplier.
SIGNAL_WEIGHT = [
    (re.compile(r"\b(profit after tax|pat|net profit|profit|earnings|eps|margin|margins)\b", re.IGNORECASE), 1.25),
    (re.compile(r"\b(revenue|sales)\b", re.IGNORECASE), 1.00),
    (re.compile(r"\b(cost|costs|expense|expenses|inflation)\b", re.IGNORECASE), 1.10),
    (re.compile(r"\b(debt|default|liquidity|cash flow)\b", re.IGNORECASE), 1.15),
]


def split_clauses(text: str):
    """Split *text* into stripped, non-empty clauses at CONTRAST_SPLIT markers.

    Returns:
        A list of clause strings. If splitting leaves nothing (e.g. the text is
        only a connective), returns a one-element list with the stripped text.
    """
    parts = [p.strip() for p in CONTRAST_SPLIT.split(text) if p and p.strip()]
    return parts if parts else [text.strip()]


def clause_weight(clause: str) -> float:
    """Return the aggregation weight of *clause*.

    The weight starts at 1.0 and is then multiplied by:
    - at most one INTENSITY multiplier: the first key, in dict order, that
      appears in the clause as a whole word;
    - at most one SIGNAL_WEIGHT multiplier: the first pattern that matches.

    Args:
        clause: A single clause, typically from split_clauses().

    Returns:
        The product of the applicable multipliers (1.0 if none apply).
    """
    w = 1.0
    low = clause.lower()
    # Intensity modifier. Whole-word matching avoids false hits from plain
    # substring checks, e.g. "materials" -> "material" or "recorded" -> "record".
    for k, mult in INTENSITY.items():
        if re.search(rf"\b{re.escape(k)}\b", low):
            w *= mult
            break
    # Signal-importance modifier.
    for pat, mult in SIGNAL_WEIGHT:
        if pat.search(clause):
            w *= mult
            break
    return w
def dist_from_scores(scores):
    """Rank classifier scores and return (distribution, top label, top score).

    Args:
        scores: Iterable of {"label": ..., "score": ...} dicts describing one
            input, e.g. one element of the classification pipeline's output.

    Returns:
        A tuple ``(dist, top_label, top_score)``. ``dist`` maps each label to
        its score as a float and is populated from highest to lowest score.

    Raises:
        IndexError: If *scores* is empty.
    """
    ranked = sorted(scores, key=lambda entry: entry["score"], reverse=True)
    dist = {}
    for entry in ranked:
        dist[entry["label"]] = float(entry["score"])
    leader = ranked[0]
    return dist, leader["label"], float(leader["score"])
def map_direction_to_5_level(direction: float, confident: bool) -> str:
    """Map a sentiment direction to one of five assessment labels.

    ``direction`` is positive_prob - negative_prob, roughly in [-1, 1].

    When *confident* is True:
    - direction >= 0.35 -> "positive"
    - direction >= 0.12 -> "moderately positive"
    - direction <= -0.35 -> "negative"
    - direction <= -0.12 -> "moderately negative"
    - anything else -> "neutral"

    When *confident* is False, only the moderate labels are used: |direction|
    below 0.10 is "neutral", otherwise the sign picks "moderately positive" or
    "moderately negative".
    """
    if confident:
        for cutoff, label in ((0.35, "positive"), (0.12, "moderately positive")):
            if direction >= cutoff:
                return label
        for cutoff, label in ((-0.35, "negative"), (-0.12, "moderately negative")):
            if direction <= cutoff:
                return label
        return "neutral"

    if abs(direction) < 0.10:
        return "neutral"
    if direction > 0:
        return "moderately positive"
    return "moderately negative"
def aggregate_5_level(text: str) -> str:
    """Score each clause of *text* separately and fuse the results into one label.

    Processing steps:
    1. Classify each clause from split_clauses() with clf.
    2. Take its direction as P("positive") - P("negative").
    3. Weight it by clause_weight().
    4. Map the weighted mean direction to a 5-level label.

    Clauses that point in opposite directions (at least one above 0.12 and at
    least one below -0.12) count as mixed evidence. In that case only
    "moderately positive", "moderately negative" or "neutral" can be returned.

    Args:
        text: Input text to assess.

    Returns:
        One of "positive", "moderately positive", "neutral",
        "moderately negative", "negative".
    """
    directions = []
    weights = []
    top_probs = []
    for clause in split_clauses(text):
        # clf(...)[0] is the list of {"label", "score"} dicts for this clause.
        dist, _, top_prob = dist_from_scores(clf(clause)[0])
        directions.append(
            float(dist.get("positive", 0.0)) - float(dist.get("negative", 0.0))
        )
        weights.append(clause_weight(clause))
        top_probs.append(top_prob)

    denominator = sum(weights) if directions else 1.0
    agg_direction = sum(d * w for d, w in zip(directions, weights)) / denominator

    # Opposing clause signals: cap the result at a moderate leaning.
    mixed = any(d > 0.12 for d in directions) and any(d < -0.12 for d in directions)
    if mixed:
        if agg_direction > 0.10:
            return "moderately positive"
        if agg_direction < -0.10:
            return "moderately negative"
        return "neutral"

    # Mean of the per-clause top probabilities is used as a confidence proxy.
    mean_top = sum(top_probs) / len(top_probs) if top_probs else 0.0
    return map_direction_to_5_level(agg_direction, mean_top >= 0.55)
# -----------------------------
# Predict
# -----------------------------
def predict(text: str):
    """Assess *text* on the 5-level finance sentiment scale.

    Args:
        text: User-supplied sentence or headline. None or blank input is
            rejected.

    Returns:
        ``{"error": ...}`` for empty input. Otherwise a dict with:
        - "assessment": the clause-aggregated label after the negation
          guardrail.
        - "confidence": the top class probability of the full-sentence
          classification, rounded to 4 decimals. This is not necessarily the
          probability of the reported assessment.
    """
    text = (text or "").strip()
    if not text:
        return {"error": "Please enter a sentence."}

    # Clause-level aggregated label.
    level = aggregate_5_level(text)

    # Full-sentence distribution, used for the guardrail and the displayed
    # confidence. Reuses the shared helper instead of duplicating its
    # sort/dict logic.
    dist, _, top_conf = dist_from_scores(clf(text)[0])

    # Negation guardrail.
    level = apply_negation_guardrail(level, dist, text)

    return {
        "assessment": level,
        "confidence": round(top_conf, 4),
    }
def _build_demo() -> gr.Interface:
    """Assemble the Gradio UI: one textbox in, predict()'s dict shown as JSON."""
    examples = [
        ["The company reported record quarterly profits and raised its full-year guidance."],
        ["Revenue did not grow, while there was not much increase in profit after tax."],
        ["Revenue grew modestly, while profit after tax decreased."],
        ["Despite revenue growth, rising costs kept margins under pressure."],
        ["The firm withdrew guidance due to uncertainty in demand conditions."],
        ["The company disclosed changes in shareholding structure."],
    ]
    text_input = gr.Textbox(
        label="Financial sentence / news headline",
        placeholder="e.g., Company reports record quarterly profit, shares jump 8%",
        lines=3,
    )
    json_output = gr.JSON(label="Prediction")
    return gr.Interface(
        fn=predict,
        inputs=text_input,
        outputs=json_output,
        title="Finance Sentiment Mini (FinBERT fine-tuned)",
        description=(
            "Finance-focused sentiment analysis using FinBERT fine-tuned on Financial PhraseBank. "
            "Outputs a 5-level assessment with clause-aware handling for mixed statements."
        ),
        examples=examples,
    )


demo = _build_demo()

if __name__ == "__main__":
    # Launched with ssr_mode=False (Gradio server-side rendering switched off).
    demo.launch(ssr_mode=False)