Liva21 commited on
Commit
709f9b2
Β·
1 Parent(s): 592e430

Initial deploy

Browse files
Files changed (2) hide show
  1. app.py +226 -0
  2. requirements.txt +6 -25
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import re
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ AutoModelForSequenceClassification,
7
+ MarianMTModel,
8
+ MarianTokenizer,
9
+ )
10
+ import numpy as np
11
+
12
+ # ─────────────────────────────────────────────
13
+ # MODEL PATHS
14
+ # ─────────────────────────────────────────────
15
+ FINBERT_PATH = "./models/finbert-finetuned"
16
+ TRANSLATE_MODEL = "Helsinki-NLP/opus-mt-tr-en"
17
+
18
+ # ─────────────────────────────────────────────
19
+ # LOAD MODELS (cached after first run)
20
+ # ─────────────────────────────────────────────
21
+ print("Loading FinBERT model...")
22
+ try:
23
+ finbert_tokenizer = AutoTokenizer.from_pretrained(FINBERT_PATH)
24
+ finbert_model = AutoModelForSequenceClassification.from_pretrained(FINBERT_PATH)
25
+ finbert_model.eval()
26
+ FINBERT_LABELS = list(finbert_model.config.id2label.values())
27
+ except Exception as e:
28
+ print(f"[WARN] Could not load local FinBERT, falling back to ProsusAI/finbert: {e}")
29
+ finbert_tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")
30
+ finbert_model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert")
31
+ finbert_model.eval()
32
+ FINBERT_LABELS = ["positive", "negative", "neutral"]
33
+
34
+ print("Loading translation model...")
35
+ tr_tokenizer = MarianTokenizer.from_pretrained(TRANSLATE_MODEL)
36
+ tr_model = MarianMTModel.from_pretrained(TRANSLATE_MODEL)
37
+ tr_model.eval()
38
+ print("All models loaded.")
39
+
40
+ # ─────────────────────────────────────────────
41
+ # FINANCIAL KEYWORDS (EN)
42
+ # ─────────────────────────────────────────────
43
+ FINANCIAL_KEYWORDS = [
44
+ "revenue", "profit", "loss", "earnings", "growth", "decline", "risk",
45
+ "investment", "market", "stock", "bond", "interest", "rate", "inflation",
46
+ "debt", "equity", "dividend", "volatility", "forecast", "outlook",
47
+ "recession", "expansion", "gdp", "cash", "flow", "asset", "liability",
48
+ "bankruptcy", "merger", "acquisition", "ipo", "shares", "fund",
49
+ ]
50
+
51
+ # ─────────────────────────────────────────────
52
+ # HELPERS
53
+ # ─────────────────────────────────────────────
54
+
55
+ def detect_language(text: str) -> str:
56
+ """Simple heuristic: Turkish-specific characters β†’ 'tr', else 'en'."""
57
+ tr_chars = set("Γ§ΔŸΔ±ΓΆΕŸΓΌΓ‡ΔžΔ°Γ–ΕžΓœ")
58
+ if any(c in tr_chars for c in text):
59
+ return "tr"
60
+ turkish_words = {"ve", "bir", "bu", "ile", "iΓ§in", "da", "de", "den", "nin",
61
+ "nΔ±n", "nun", "nΓΌn", "Δ±n", "in", "un", "ΓΌn", "yΔ±", "yi",
62
+ "yu", "yΓΌ", "ta", "te", "tan", "ten"}
63
+ words = set(text.lower().split())
64
+ if len(words & turkish_words) >= 2:
65
+ return "tr"
66
+ return "en"
67
+
68
+
69
+ def translate_tr_to_en(text: str) -> str:
70
+ inputs = tr_tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=512)
71
+ with torch.no_grad():
72
+ translated = tr_model.generate(**inputs)
73
+ return tr_tokenizer.decode(translated[0], skip_special_tokens=True)
74
+
75
+
76
+ def extract_keywords(text: str) -> list[str]:
77
+ words = re.findall(r'\b\w+\b', text.lower())
78
+ found = [w for w in words if w in FINANCIAL_KEYWORDS]
79
+ return list(dict.fromkeys(found)) # deduplicate, preserve order
80
+
81
+
82
+ def get_risk_level(label: str, confidence: float) -> str:
83
+ label = label.lower()
84
+ if label == "negative":
85
+ if confidence >= 0.80:
86
+ return "πŸ”΄ HIGH RISK"
87
+ elif confidence >= 0.55:
88
+ return "🟠 MEDIUM RISK"
89
+ else:
90
+ return "🟑 LOW-MEDIUM RISK"
91
+ elif label == "positive":
92
+ if confidence >= 0.80:
93
+ return "🟒 LOW RISK"
94
+ else:
95
+ return "🟑 LOW-MEDIUM RISK"
96
+ else:
97
+ return "🟑 NEUTRAL / MONITOR"
98
+
99
+
100
+ def run_finbert(text: str):
101
+ inputs = finbert_tokenizer(text, return_tensors="pt", truncation=True,
102
+ max_length=512, padding=True)
103
+ with torch.no_grad():
104
+ outputs = finbert_model(**inputs)
105
+ probs = torch.softmax(outputs.logits, dim=-1).squeeze().numpy()
106
+ idx = int(np.argmax(probs))
107
+ label = FINBERT_LABELS[idx]
108
+ confidence = float(probs[idx])
109
+ return label, confidence, probs
110
+
111
+
112
+ # ─────────────────────────────────────────────
113
+ # MAIN PREDICT FUNCTION
114
+ # ─────────────────────────────────────────────
115
+
116
+ def analyze(text: str):
117
+ if not text or not text.strip():
118
+ return "⚠️ Please enter some text.", "", "", "", ""
119
+
120
+ lang = detect_language(text)
121
+ original_text = text
122
+
123
+ if lang == "tr":
124
+ translated_text = translate_tr_to_en(text)
125
+ lang_info = f"🌐 Detected: **Turkish** β†’ translated to English"
126
+ else:
127
+ translated_text = text
128
+ lang_info = "🌐 Detected: **English**"
129
+
130
+ label, confidence, all_probs = run_finbert(translated_text)
131
+ risk = get_risk_level(label, confidence)
132
+ keywords = extract_keywords(translated_text)
133
+
134
+ sentiment_emoji = {"positive": "πŸ“ˆ", "negative": "πŸ“‰", "neutral": "➑️"}
135
+ emoji = sentiment_emoji.get(label.lower(), "❓")
136
+
137
+ label_display = f"{emoji} {label.upper()}"
138
+ confidence_display = f"{confidence*100:.1f}%"
139
+ keywords_display = ", ".join(keywords) if keywords else "β€”"
140
+
141
+ # Build score breakdown
142
+ scores_md = "\n".join(
143
+ [f"- **{FINBERT_LABELS[i]}**: {all_probs[i]*100:.1f}%"
144
+ for i in range(len(FINBERT_LABELS))]
145
+ )
146
+
147
+ translation_note = (
148
+ f"\n\n**Translated text:** _{translated_text}_"
149
+ if lang == "tr" else ""
150
+ )
151
+
152
+ summary = (
153
+ f"{lang_info}{translation_note}\n\n"
154
+ f"### Score Breakdown\n{scores_md}"
155
+ )
156
+
157
+ return label_display, confidence_display, risk, keywords_display, summary
158
+
159
+
160
+ # ─────────────────────────────────────────────
161
+ # GRADIO UI
162
+ # ─────────────────────────────────────────────
163
+
164
+ with gr.Blocks(
165
+ title="Financial Sentiment Analysis API",
166
+ theme=gr.themes.Soft(primary_hue="blue"),
167
+ css="""
168
+ .result-box { border-radius: 8px; padding: 8px; }
169
+ footer { display: none !important; }
170
+ """,
171
+ ) as demo:
172
+
173
+ gr.Markdown(
174
+ """
175
+ # πŸ“Š Financial Sentiment Analysis
176
+ ### Powered by FinBERT Β· Supports Turkish & English
177
+ Paste any financial news headline, earnings summary, or analyst comment.
178
+ """
179
+ )
180
+
181
+ with gr.Row():
182
+ with gr.Column(scale=2):
183
+ text_input = gr.Textbox(
184
+ label="πŸ“ Input Text (Turkish or English)",
185
+ placeholder="e.g. 'Company reported record profits this quarter' or 'Şirket bu çeyrekte rekor kar açıkladı'",
186
+ lines=5,
187
+ )
188
+ submit_btn = gr.Button("πŸ” Analyze Sentiment", variant="primary", size="lg")
189
+
190
+ with gr.Column(scale=1):
191
+ out_label = gr.Textbox(label="Sentiment Label", elem_classes="result-box")
192
+ out_confidence = gr.Textbox(label="Confidence Score", elem_classes="result-box")
193
+ out_risk = gr.Textbox(label="Risk Level", elem_classes="result-box")
194
+ out_keywords = gr.Textbox(label="Financial Keywords", elem_classes="result-box")
195
+
196
+ out_summary = gr.Markdown(label="Details")
197
+
198
+ submit_btn.click(
199
+ fn=analyze,
200
+ inputs=[text_input],
201
+ outputs=[out_label, out_confidence, out_risk, out_keywords, out_summary],
202
+ )
203
+
204
+ gr.Examples(
205
+ examples=[
206
+ ["The company reported a significant drop in quarterly earnings due to supply chain disruptions."],
207
+ ["Strong revenue growth and expanding margins signal a bullish outlook for investors."],
208
+ ["Şirketin hisse senetleri, beklentilerin üzerinde kar açıklamasının ardından yükseldi."],
209
+ ["Merkez bankasΔ± faiz oranlarΔ±nΔ± artΔ±rarak enflasyonla mΓΌcadele etmeye devam ediyor."],
210
+ ["Markets remained flat as investors awaited the Federal Reserve's rate decision."],
211
+ ],
212
+ inputs=text_input,
213
+ label="πŸ“Œ Example Inputs",
214
+ )
215
+
216
+ gr.Markdown(
217
+ """
218
+ ---
219
+ **Model:** Fine-tuned FinBERT for financial sentiment classification
220
+ **Translation:** Helsinki-NLP/opus-mt-tr-en for Turkish→English
221
+ **Labels:** Positive Β· Negative Β· Neutral
222
+ """
223
+ )
224
+
225
+ if __name__ == "__main__":
226
+ demo.launch()
requirements.txt CHANGED
@@ -1,25 +1,6 @@
1
- torch>=2.2.0
2
- transformers>=4.40.0
3
- datasets==2.16.0
4
- scikit-learn>=1.3.2
5
- pandas>=2.1.4
6
- numpy<2.0.0
7
- matplotlib>=3.8.2
8
- seaborn==0.13.0
9
- jupyter==1.0.0
10
- ipykernel==6.27.1
11
- fastapi==0.109.0
12
- uvicorn[standard]==0.27.0
13
- python-dotenv==1.0.0
14
- pydantic>=2.5.3
15
- accelerate==0.25.0
16
- langdetect==1.0.9
17
- sentencepiece
18
- streamlit==1.31.0
19
- plotly==5.18.0
20
- sentencepiece==0.1.99
21
- sacremoses==0.0.53
22
- feedparser==6.0.11
23
- schedule==1.2.1
24
- beautifulsoup4==4.12.3
25
- pytest==7.4.4
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ transformers>=4.35.0
4
+ sentencepiece>=0.1.99
5
+ sacremoses>=0.0.53
6
+ numpy>=1.24.0