# Captured from a Hugging Face Spaces page (Space status: Running).
import re

import gradio as gr
import numpy as np
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    MarianMTModel,
    MarianTokenizer,
)
# ─────────────────────────────────────────────
# MODEL PATHS
# ─────────────────────────────────────────────
FINBERT_PATH = "./models/finbert-finetuned"
FINBERT_FALLBACK = "ProsusAI/finbert"
TRANSLATE_MODEL = "Helsinki-NLP/opus-mt-tr-en"


# ─────────────────────────────────────────────
# LOAD MODELS (cached after first run)
# ─────────────────────────────────────────────
def _load_finbert(source: str):
    """Load a sequence-classification checkpoint together with its labels.

    Args:
        source: Local directory or hub identifier passed to
            ``from_pretrained``.

    Returns:
        ``(tokenizer, model, labels)`` where the model is in eval mode and
        ``labels[i]`` is the class name for logit index ``i``.
    """
    tokenizer = AutoTokenizer.from_pretrained(source)
    model = AutoModelForSequenceClassification.from_pretrained(source)
    model.eval()
    # Order labels by class index rather than by dict insertion order, so
    # labels[i] always lines up with logit i.
    id2label = model.config.id2label
    labels = [id2label[key] for key in sorted(id2label, key=int)]
    return tokenizer, model, labels


print("Loading FinBERT model...")
try:
    finbert_tokenizer, finbert_model, FINBERT_LABELS = _load_finbert(FINBERT_PATH)
except Exception as e:
    # Deliberately broad: any failure to load the local fine-tuned checkpoint
    # degrades to the public base model instead of aborting start-up.
    print(f"[WARN] Could not load local FinBERT, falling back to ProsusAI/finbert: {e}")
    finbert_tokenizer, finbert_model, FINBERT_LABELS = _load_finbert(FINBERT_FALLBACK)

print("Loading translation model...")
tr_tokenizer = MarianTokenizer.from_pretrained(TRANSLATE_MODEL)
tr_model = MarianMTModel.from_pretrained(TRANSLATE_MODEL)
tr_model.eval()
print("All models loaded.")
# ─────────────────────────────────────────────
# FINANCIAL KEYWORDS (EN)
# ─────────────────────────────────────────────
# Lower-case English finance terms, kept as an ordered list. Written as one
# whitespace-separated string so terms can be added without quoting noise.
FINANCIAL_KEYWORDS = (
    "revenue profit loss earnings growth decline risk "
    "investment market stock bond interest rate inflation "
    "debt equity dividend volatility forecast outlook "
    "recession expansion gdp cash flow asset liability "
    "bankruptcy merger acquisition ipo shares fund"
).split()
# ─────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────
# Letters of the Turkish alphabet that do not occur in English text.
_TURKISH_CHARS = frozenset("çğıöşüÇĞİÖŞÜ")

# Turkish function words / detached suffixes used as a fallback signal.
# Entries that are also everyday English tokens ("in", "ten", "tan", "den",
# "nun", "un") are intentionally left out: two of them in one English
# sentence would otherwise route it through the Turkish translator.
_TURKISH_MARKERS = frozenset({
    "ve", "bir", "bu", "ile", "için", "da", "de", "nin", "nın", "nün",
    "ın", "ün", "yı", "yi", "yu", "yü", "ta", "te",
})

# A word, optionally with internal apostrophes, so "we've" stays one token
# while trailing punctuation ("ve,") is stripped.
_WORD_RE = re.compile(r"\w+(?:['’]\w+)*")


def detect_language(text: str) -> str:
    """Guess whether *text* is Turkish or English.

    Heuristic only: any Turkish-specific letter means Turkish; otherwise the
    text is Turkish when it contains at least two distinct marker words.

    Args:
        text: Input text; may be empty.

    Returns:
        ``"tr"`` for Turkish, ``"en"`` for everything else.
    """
    if any(ch in _TURKISH_CHARS for ch in text):
        return "tr"

    tokens = set(_WORD_RE.findall(text.lower()))
    if len(tokens & _TURKISH_MARKERS) >= 2:
        return "tr"
    return "en"
def translate_tr_to_en(text: str) -> str:
    """Translate *text* with the module-level Marian model and tokenizer.

    The input is tokenized as a batch of one (truncated to 512 tokens),
    generation runs without gradient tracking, and the first generated
    sequence is decoded with special tokens removed.
    """
    encoded = tr_tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        generated_ids = tr_model.generate(**encoded)
    first_sequence = generated_ids[0]
    return tr_tokenizer.decode(first_sequence, skip_special_tokens=True)
def extract_keywords(text: str) -> list[str]:
    """Return the financial keywords found in *text*.

    Matching is case-insensitive on whole words. Each keyword appears once,
    in order of first occurrence.
    """
    first_seen: dict[str, None] = {}
    for token in re.findall(r'\b\w+\b', text.lower()):
        if token in FINANCIAL_KEYWORDS:
            # dict keys keep insertion order, giving an ordered de-duplication.
            first_seen.setdefault(token, None)
    return list(first_seen)
def get_risk_level(label: str, confidence: float) -> str:
    """Map a sentiment label and its confidence to a risk badge.

    Args:
        label: Sentiment label; compared case-insensitively against
            ``"negative"`` and ``"positive"``. Anything else is treated as
            neutral.
        confidence: Probability of the predicted label.

    Returns:
        A display string made of a coloured-circle emoji and a risk tier.
    """
    # NOTE(review): the emoji were restored from mis-decoded UTF-8 bytes
    # (red / orange / yellow / green circles); confirm against the original.
    sentiment = label.lower()

    if sentiment == "negative":
        if confidence >= 0.80:
            return "🔴 HIGH RISK"
        if confidence >= 0.55:
            return "🟠 MEDIUM RISK"
        return "🟡 LOW-MEDIUM RISK"

    if sentiment == "positive":
        if confidence >= 0.80:
            return "🟢 LOW RISK"
        return "🟡 LOW-MEDIUM RISK"

    return "🟡 NEUTRAL / MONITOR"
def run_finbert(text: str) -> tuple[str, float, np.ndarray]:
    """Classify *text* with the module-level FinBERT model.

    Args:
        text: English text; tokenized with truncation at 512 tokens.

    Returns:
        ``(label, confidence, probs)``: the highest-scoring entry of
        ``FINBERT_LABELS``, its softmax probability, and the per-label
        probabilities as a NumPy array.
    """
    inputs = finbert_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )
    with torch.no_grad():
        outputs = finbert_model(**inputs)

    # Take the single batch row explicitly. A bare squeeze() would also drop
    # the label axis if the head ever had one output, leaving a 0-d array.
    # Assumes logits is (batch=1, num_labels) — TODO confirm.
    probs = torch.softmax(outputs.logits, dim=-1)[0].numpy()
    idx = int(np.argmax(probs))
    label = FINBERT_LABELS[idx]
    confidence = float(probs[idx])
    return label, confidence, probs
# ─────────────────────────────────────────────
# MAIN PREDICT FUNCTION
# ─────────────────────────────────────────────
def analyze(text: str):
    """Run the full pipeline on *text* and format the results for the UI.

    Turkish input (per ``detect_language``) is translated to English first;
    sentiment, risk level and keywords are then computed on the English text.

    Args:
        text: User input; ``None``, empty or whitespace-only input
            short-circuits with a prompt message.

    Returns:
        A 5-tuple of strings: sentiment label, confidence percentage, risk
        level, comma-separated keywords, and a Markdown summary.
    """
    # NOTE(review): emoji and symbols below were restored from mis-decoded
    # UTF-8. Where every byte was lost (globe, chart icons, question mark)
    # the glyph is a best guess — confirm against the original.
    if not text or not text.strip():
        return "⚠️ Please enter some text.", "", "", "", ""

    lang = detect_language(text)
    if lang == "tr":
        translated_text = translate_tr_to_en(text)
        lang_info = "🌐 Detected: **Turkish** → translated to English"
        translation_note = f"\n\n**Translated text:** _{translated_text}_"
    else:
        translated_text = text
        lang_info = "🌐 Detected: **English**"
        translation_note = ""

    label, confidence, all_probs = run_finbert(translated_text)
    risk = get_risk_level(label, confidence)
    keywords = extract_keywords(translated_text)

    sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}
    emoji = sentiment_emoji.get(label.lower(), "❓")

    label_display = f"{emoji} {label.upper()}"
    confidence_display = f"{confidence*100:.1f}%"
    keywords_display = ", ".join(keywords) if keywords else "—"

    # Build score breakdown: one Markdown bullet per label.
    scores_md = "\n".join(
        f"- **{name}**: {prob*100:.1f}%"
        for name, prob in zip(FINBERT_LABELS, all_probs)
    )
    summary = (
        f"{lang_info}{translation_note}\n\n"
        f"### Score Breakdown\n{scores_md}"
    )
    return label_display, confidence_display, risk, keywords_display, summary
# ─────────────────────────────────────────────
# GRADIO UI
# ─────────────────────────────────────────────
# NOTE(review): indentation was lost in the captured source, so the layout
# nesting below is reconstructed from call order (summary placed under the
# row). Decorative emoji were restored from mis-decoded UTF-8 and are best
# guesses; the Turkish sample sentences are restored exactly.
with gr.Blocks(
    title="Financial Sentiment Analysis API",
    theme=gr.themes.Soft(primary_hue="blue"),
    css="""
    .result-box { border-radius: 8px; padding: 8px; }
    footer { display: none !important; }
    """,
) as demo:
    gr.Markdown(
        """
        # 📊 Financial Sentiment Analysis
        ### Powered by FinBERT · Supports Turkish & English
        Paste any financial news headline, earnings summary, or analyst comment.
        """
    )

    with gr.Row():
        # Left column: input text and the trigger button.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="📝 Input Text (Turkish or English)",
                placeholder="e.g. 'Company reported record profits this quarter' or 'Şirket bu çeyrekte rekor kar açıkladı'",
                lines=5,
            )
            submit_btn = gr.Button("🔍 Analyze Sentiment", variant="primary", size="lg")

        # Right column: one read-out per scalar result.
        with gr.Column(scale=1):
            out_label = gr.Textbox(label="Sentiment Label", elem_classes="result-box")
            out_confidence = gr.Textbox(label="Confidence Score", elem_classes="result-box")
            out_risk = gr.Textbox(label="Risk Level", elem_classes="result-box")
            out_keywords = gr.Textbox(label="Financial Keywords", elem_classes="result-box")

    out_summary = gr.Markdown(label="Details")

    # Output order must match the tuple returned by analyze().
    submit_btn.click(
        fn=analyze,
        inputs=[text_input],
        outputs=[out_label, out_confidence, out_risk, out_keywords, out_summary],
    )

    gr.Examples(
        examples=[
            ["The company reported a significant drop in quarterly earnings due to supply chain disruptions."],
            ["Strong revenue growth and expanding margins signal a bullish outlook for investors."],
            ["Şirketin hisse senetleri, beklentilerin üzerinde kar açıklamasının ardından yükseldi."],
            ["Merkez bankası faiz oranlarını artırarak enflasyonla mücadele etmeye devam ediyor."],
            ["Markets remained flat as investors awaited the Federal Reserve's rate decision."],
        ],
        inputs=text_input,
        label="📌 Example Inputs",
    )

    gr.Markdown(
        """
        ---
        **Model:** Fine-tuned FinBERT for financial sentiment classification
        **Translation:** Helsinki-NLP/opus-mt-tr-en for Turkish→English
        **Labels:** Positive · Negative · Neutral
        """
    )

if __name__ == "__main__":
    demo.launch()