codealchemist01 commited on
Commit
20fd239
·
verified ·
1 Parent(s): f35ea6d

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +473 -234
app.py CHANGED
@@ -1,234 +1,473 @@
1
- #!/usr/bin/env python3
2
- """
3
- Hugging Face Space App for Financial Sentiment Analysis Ensemble
4
- """
5
-
6
- import gradio as gr
7
- import torch
8
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
- import numpy as np
10
- from datetime import datetime
11
- import json
12
-
13
- class FinancialSentimentEnsemble:
14
- def __init__(self):
15
- self.models = {}
16
- self.tokenizers = {}
17
- self.model_names = [
18
- "codealchemist01/financial-sentiment-distilbert",
19
- "codealchemist01/financial-sentiment-bert-large",
20
- "codealchemist01/financial-sentiment-improved"
21
- ]
22
- self.labels = ["Bearish πŸ“‰", "Neutral ➑️", "Bullish πŸ“ˆ"]
23
- self.load_models()
24
-
25
- def load_models(self):
26
- """Load all models and tokenizers"""
27
- print("πŸš€ Loading Financial Sentiment Analysis Ensemble...")
28
-
29
- for i, model_name in enumerate(self.model_names):
30
- try:
31
- print(f"πŸ“₯ Loading {model_name}...")
32
- self.tokenizers[i] = AutoTokenizer.from_pretrained(model_name)
33
- self.models[i] = AutoModelForSequenceClassification.from_pretrained(model_name)
34
- self.models[i].eval()
35
- print(f"βœ… {model_name} loaded successfully!")
36
- except Exception as e:
37
- print(f"❌ Error loading {model_name}: {e}")
38
-
39
- print(f"πŸŽ‰ Ensemble ready with {len(self.models)} models!")
40
-
41
- def predict_single_model(self, text, model_idx):
42
- """Predict sentiment using a single model"""
43
- if model_idx not in self.models:
44
- return None
45
-
46
- try:
47
- inputs = self.tokenizers[model_idx](
48
- text,
49
- return_tensors="pt",
50
- truncation=True,
51
- padding=True,
52
- max_length=512
53
- )
54
-
55
- with torch.no_grad():
56
- outputs = self.models[model_idx](**inputs)
57
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
58
-
59
- return probabilities[0].numpy()
60
- except Exception as e:
61
- print(f"Error in model {model_idx}: {e}")
62
- return None
63
-
64
- def predict_ensemble(self, text):
65
- """Predict sentiment using ensemble of all models"""
66
- if not text.strip():
67
- return "Please enter some text to analyze.", {}, {}
68
-
69
- individual_predictions = {}
70
- all_probabilities = []
71
-
72
- # Get predictions from each model
73
- for i, model_name in enumerate(self.model_names):
74
- probs = self.predict_single_model(text, i)
75
- if probs is not None:
76
- all_probabilities.append(probs)
77
-
78
- # Individual model results
79
- predicted_class = np.argmax(probs)
80
- confidence = probs[predicted_class]
81
-
82
- model_short_name = model_name.split("/")[-1].replace("financial-sentiment-", "").title()
83
- individual_predictions[f"{model_short_name}"] = {
84
- "Prediction": self.labels[predicted_class],
85
- "Confidence": f"{confidence:.1%}"
86
- }
87
-
88
- if not all_probabilities:
89
- return "Error: No models available for prediction.", {}, {}
90
-
91
- # Ensemble prediction (average probabilities)
92
- ensemble_probs = np.mean(all_probabilities, axis=0)
93
- ensemble_prediction = np.argmax(ensemble_probs)
94
- ensemble_confidence = ensemble_probs[ensemble_prediction]
95
-
96
- # Create probability distribution for visualization
97
- prob_dict = {}
98
- for i, label in enumerate(self.labels):
99
- prob_dict[label] = float(ensemble_probs[i])
100
-
101
- # Result summary
102
- result_text = f"""
103
- ## 🎯 Ensemble Prediction: **{self.labels[ensemble_prediction]}**
104
- **Confidence:** {ensemble_confidence:.1%}
105
-
106
- ### πŸ“Š Probability Distribution:
107
- - πŸ“‰ Bearish: {ensemble_probs[0]:.1%}
108
- - ➑️ Neutral: {ensemble_probs[1]:.1%}
109
- - πŸ“ˆ Bullish: {ensemble_probs[2]:.1%}
110
-
111
- ### πŸ€– Individual Model Results:
112
- """
113
-
114
- for model_name, result in individual_predictions.items():
115
- result_text += f"- **{model_name}**: {result['Prediction']} ({result['Confidence']})\n"
116
-
117
- return result_text, prob_dict, individual_predictions
118
-
119
- # Initialize the ensemble
120
- ensemble = FinancialSentimentEnsemble()
121
-
122
- def analyze_sentiment(text):
123
- """Main function for Gradio interface"""
124
- return ensemble.predict_ensemble(text)
125
-
126
- # Example texts for demonstration
127
- examples = [
128
- "The stock market is showing strong bullish momentum with record highs across major indices.",
129
- "Company earnings fell short of expectations, leading to a significant drop in share price.",
130
- "The Federal Reserve maintained interest rates, keeping market conditions stable.",
131
- "Tesla's innovative battery technology could revolutionize the automotive industry.",
132
- "Rising inflation concerns are creating uncertainty in the financial markets.",
133
- "The merger announcement sent both companies' stock prices soaring.",
134
- "Quarterly results were mixed, with some sectors outperforming while others lagged."
135
- ]
136
-
137
- # Create Gradio interface
138
- with gr.Blocks(
139
- theme=gr.themes.Soft(),
140
- title="Financial Sentiment Analysis Ensemble",
141
- css="""
142
- .gradio-container {
143
- max-width: 1200px !important;
144
- }
145
- .main-header {
146
- text-align: center;
147
- margin-bottom: 2rem;
148
- }
149
- .model-info {
150
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
151
- color: white;
152
- padding: 1rem;
153
- border-radius: 10px;
154
- margin: 1rem 0;
155
- }
156
- """
157
- ) as demo:
158
-
159
- gr.HTML("""
160
- <div class="main-header">
161
- <h1>πŸš€ Financial Sentiment Analysis Ensemble</h1>
162
- <p>Advanced AI-powered sentiment analysis for financial texts using an ensemble of 3 fine-tuned models</p>
163
- </div>
164
- """)
165
-
166
- with gr.Row():
167
- with gr.Column(scale=2):
168
- text_input = gr.Textbox(
169
- label="πŸ“ Enter Financial Text",
170
- placeholder="Type or paste financial news, social media posts, or market commentary here...",
171
- lines=4,
172
- max_lines=10
173
- )
174
-
175
- analyze_btn = gr.Button("πŸ” Analyze Sentiment", variant="primary", size="lg")
176
-
177
- gr.Examples(
178
- examples=examples,
179
- inputs=text_input,
180
- label="πŸ’‘ Try these examples:"
181
- )
182
-
183
- with gr.Column(scale=3):
184
- result_output = gr.Markdown(label="πŸ“Š Analysis Results")
185
-
186
- with gr.Row():
187
- prob_plot = gr.BarPlot(
188
- x="Sentiment",
189
- y="Probability",
190
- title="Ensemble Probability Distribution",
191
- x_title="Sentiment Categories",
192
- y_title="Probability",
193
- width=400,
194
- height=300
195
- )
196
-
197
- individual_results = gr.JSON(
198
- label="πŸ€– Individual Model Predictions",
199
- visible=True
200
- )
201
-
202
- # Model Information
203
- gr.HTML("""
204
- <div class="model-info">
205
- <h3>🧠 Ensemble Models:</h3>
206
- <ul>
207
- <li><strong>DistilBERT Model:</strong> Fast and efficient, optimized for real-time analysis</li>
208
- <li><strong>BERT-Large Model:</strong> High accuracy with deep contextual understanding</li>
209
- <li><strong>Improved Model:</strong> Enhanced with advanced training techniques</li>
210
- </ul>
211
- <p><strong>Ensemble Accuracy:</strong> 79.7% | <strong>Categories:</strong> Bearish πŸ“‰, Neutral ➑️, Bullish πŸ“ˆ</p>
212
- </div>
213
- """)
214
-
215
- # Event handlers
216
- analyze_btn.click(
217
- fn=analyze_sentiment,
218
- inputs=text_input,
219
- outputs=[result_output, prob_plot, individual_results]
220
- )
221
-
222
- text_input.submit(
223
- fn=analyze_sentiment,
224
- inputs=text_input,
225
- outputs=[result_output, prob_plot, individual_results]
226
- )
227
-
228
- # Launch the app
229
- if __name__ == "__main__":
230
- demo.launch(
231
- server_name="0.0.0.0",
232
- server_port=7860,
233
- share=False
234
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Financial Sentiment Analysis - Enhanced Ensemble Gradio Demo for Hugging Face Space
4
+ 3-Model Ensemble System with Rule Engine
5
+ """
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import numpy as np
10
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
+ import logging
12
+ import re
13
+ from typing import Dict, List, Tuple
14
+
15
+ # Logging setup
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
class SentimentRuleEngine:
    """Keyword-driven post-processor for model sentiment probabilities.

    Scans the input text for weighted bullish / bearish / neutral cue words
    and nudges the ensemble's probability vector toward the matching class.
    Class index order is [0] Bearish, [1] Neutral, [2] Bullish.
    """

    def __init__(self):
        # Strong bullish cue words, each with a weight in (0, 1].
        self.bullish_keywords = {
            'soaring': 0.9, 'skyrocketing': 0.9, 'surging': 0.9, 'exploding': 0.9,
            'excellent': 0.8, 'outstanding': 0.8, 'exceptional': 0.8, 'amazing': 0.8,
            'breakthrough': 0.8, 'revolutionary': 0.8, 'record-breaking': 0.9,
            'all-time high': 0.9, 'new high': 0.8, 'moon': 0.8, 'rocket': 0.8,
            'mooning': 0.9, 'rocketing': 0.8, 'booming': 0.7, 'thriving': 0.7,
            'up 10%': 0.8, 'up 15%': 0.9, 'up 20%': 0.9, 'gained 10%': 0.8,
            'rose 15%': 0.8, 'jumped 20%': 0.9, 'spiked': 0.8, 'surged': 0.8,
            'rising': 0.6, 'climbing': 0.6, 'gaining': 0.6, 'growing': 0.6,
            'strong': 0.5, 'solid': 0.5, 'robust': 0.5, 'healthy': 0.5,
            'positive': 0.4, 'optimistic': 0.5, 'bullish': 0.8, 'rally': 0.7,
            'beat': 0.7, 'exceeded': 0.7, 'outperformed': 0.7, 'success': 0.6,
            'profit': 0.3, 'earnings': 0.2, 'revenue': 0.2, 'growth': 0.5
        }

        # Strong bearish cue words with weights.
        self.bearish_keywords = {
            'crashing': 0.9, 'plummeting': 0.9, 'collapsing': 0.9, 'tanking': 0.9,
            'disaster': 0.8, 'terrible': 0.8, 'awful': 0.8, 'horrible': 0.8,
            'crisis': 0.7, 'recession': 0.8, 'bankruptcy': 0.9, 'failed': 0.7,
            'down 10%': 0.8, 'down 15%': 0.9, 'down 20%': 0.9, 'lost 10%': 0.8,
            'fell 15%': 0.8, 'dropped 20%': 0.9, 'plunged': 0.8, 'tumbled': 0.7,
            'falling': 0.6, 'declining': 0.6, 'dropping': 0.6, 'losing': 0.6,
            'weak': 0.5, 'poor': 0.5, 'bad': 0.4, 'negative': 0.4,
            'bearish': 0.8, 'selloff': 0.7, 'sell-off': 0.7, 'correction': 0.6,
            'missed': 0.6, 'disappointed': 0.6, 'concerns': 0.4, 'worried': 0.5
        }

        # Cue words that should pull extreme predictions back toward neutral.
        self.neutral_keywords = {
            'mixed': 0.7, 'uncertain': 0.6, 'unclear': 0.6, 'sideways': 0.8,
            'flat': 0.7, 'stable': 0.5, 'unchanged': 0.8, 'waiting': 0.6,
            'consolidating': 0.7, 'range-bound': 0.8, 'choppy': 0.7
        }

    def extract_keywords(self, text: str) -> Dict[str, List[Tuple[str, float]]]:
        """Return, per category, the (keyword, weight) pairs found in *text*.

        Matching is case-insensitive substring containment, so e.g. 'bad'
        also fires inside longer words.
        """
        lowered = text.lower()
        hits: Dict[str, List[Tuple[str, float]]] = {}
        for category, lexicon in (
            ('bullish', self.bullish_keywords),
            ('bearish', self.bearish_keywords),
            ('neutral', self.neutral_keywords),
        ):
            hits[category] = [(kw, w) for kw, w in lexicon.items() if kw in lowered]
        return hits

    def apply_rules(self, text: str, model_probabilities: np.ndarray,
                    confidence_threshold: float = 0.7) -> Tuple[np.ndarray, str]:
        """Nudge *model_probabilities* toward categories whose keywords fire.

        Returns the re-normalized probability vector and a human-readable
        note describing which adjustments were applied.
        ``confidence_threshold`` is accepted for API compatibility but is
        not currently consulted.
        """
        adjusted = model_probabilities.copy()
        hits = self.extract_keywords(text)

        # Aggregate keyword weight per category.
        scores = {cat: sum(w for _, w in pairs) for cat, pairs in hits.items()}
        notes: List[str] = []

        def nudge(target: int, boost: float, label: str, score: float) -> None:
            # Raise the target class; take half the boost from each of the
            # other two classes, flooring them at 0.05.
            adjusted[target] += boost
            for other in range(3):
                if other != target:
                    adjusted[other] = max(0.05, adjusted[other] - boost / 2)
            notes.append(f"{label} keywords detected (score: {score:.2f})")

        if scores['bullish'] > 0.5:
            nudge(2, min(0.3, scores['bullish'] * 0.2), "Bullish", scores['bullish'])
        if scores['bearish'] > 0.5:
            nudge(0, min(0.3, scores['bearish'] * 0.2), "Bearish", scores['bearish'])
        if scores['neutral'] > 0.5:
            nudge(1, min(0.2, scores['neutral'] * 0.15), "Neutral", scores['neutral'])

        # Re-normalize so the adjusted vector sums to 1.
        adjusted = adjusted / np.sum(adjusted)

        if notes:
            explanation = "Applied: " + ", ".join(notes)
        else:
            explanation = "No significant keywords detected"

        return adjusted, explanation

# Initialize rule engine
rule_engine = SentimentRuleEngine()
136
+
137
class FinancialSentimentEnsemble:
    """Ensemble model for financial sentiment analysis using Hugging Face models.

    Wraps up to three fine-tuned sequence-classification checkpoints and
    combines their softmax outputs with fixed per-model weights.
    Class index order is [0] Bearish, [1] Neutral, [2] Bullish.
    """

    def __init__(self):
        # All models share one device; GPU when available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = {}
        self.tokenizers = {}
        self.label_names = ["Bearish πŸ“‰", "Neutral βš–οΈ", "Bullish πŸ“ˆ"]

        # Hugging Face model configurations (key -> display name / repo / blurb).
        self.model_configs = {
            "distilbert": {
                "name": "DistilBERT (Fast)",
                "repo_id": "codealchemist01/financial-sentiment-distilbert",
                "description": "Fast and efficient model"
            },
            "bert_large": {
                "name": "BERT-Large (Advanced)",
                "repo_id": "codealchemist01/financial-sentiment-bert-large",
                "description": "Most advanced model"
            },
            "improved": {
                "name": "Improved Model",
                "repo_id": "codealchemist01/financial-sentiment-improved",
                "description": "Enhanced model with advanced training"
            }
        }

        # Fixed weights for the predefined ensemble combinations;
        # single-model selections get weight 1.0 at prediction time.
        self.ensemble_weights = {
            "smart_ensemble": {"distilbert": 0.3, "bert_large": 0.7},
            "all_models": {"distilbert": 0.2, "improved": 0.3, "bert_large": 0.5}
        }

        self.load_models()

    def load_models(self):
        """Load every configured checkpoint from the Hub; skip failures.

        A model that fails to download/load is logged and omitted so the app
        can still serve predictions from the remaining models.

        Returns:
            list[str]: display names of the models that loaded successfully.
        """
        loaded_models = []

        for model_key, config in self.model_configs.items():
            try:
                logger.info("Loading %s from %s", config["name"], config["repo_id"])

                tokenizer = AutoTokenizer.from_pretrained(config["repo_id"])
                model = AutoModelForSequenceClassification.from_pretrained(config["repo_id"])
                model.to(self.device)
                model.eval()

                self.tokenizers[model_key] = tokenizer
                self.models[model_key] = model
                loaded_models.append(config["name"])

                logger.info("βœ… %s loaded successfully", config["name"])

            except Exception as e:
                # Best-effort: a broken checkpoint must not take the app down.
                logger.error("❌ Error loading %s: %s", config["name"], e)

        logger.info("🎯 Total loaded models: %d", len(loaded_models))
        return loaded_models

    def predict_single_model(self, text, model_key):
        """Run one model on *text*.

        Returns:
            (probabilities, error): exactly one is None.  ``probabilities``
            is a length-3 numpy array (Bearish, Neutral, Bullish).
        """
        if model_key not in self.models:
            return None, f"Model {model_key} not available"

        try:
            tokenizer = self.tokenizers[model_key]
            model = self.models[model_key]

            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                probabilities = probabilities.cpu().numpy()[0]

            return probabilities, None

        except Exception as e:
            return None, f"Error in {model_key}: {str(e)}"

    def predict_ensemble(self, text, ensemble_type="smart_ensemble", use_rules=True):
        """Make an ensemble (or single-model) prediction for *text*.

        Args:
            text: input text; blank input short-circuits with a prompt message.
            ensemble_type: "smart_ensemble", "all_models", or a single model key.
            use_rules: apply the keyword rule engine to the averaged probabilities.

        Returns:
            (result_text, prob_dict, comparison_details) for the Gradio UI.
        """
        if not text.strip():
            return "Please enter some text to analyze.", {}, ""

        try:
            # Determine which models to use and their weights.
            if ensemble_type == "smart_ensemble":
                weights = self.ensemble_weights["smart_ensemble"]
                models_to_use = ["distilbert", "bert_large"]
            elif ensemble_type == "all_models":
                weights = self.ensemble_weights["all_models"]
                models_to_use = ["distilbert", "improved", "bert_large"]
            else:
                # Single model prediction.
                models_to_use = [ensemble_type]
                weights = {ensemble_type: 1.0}

            # Weighted sum of per-model probability vectors.
            ensemble_probabilities = np.zeros(3)
            total_weight = 0
            model_predictions = {}
            model_details = []

            for model_key in models_to_use:
                if model_key in self.models:
                    probabilities, error = self.predict_single_model(text, model_key)
                    if probabilities is None:
                        # FIX: the error string was previously discarded silently.
                        logger.warning("Skipping %s: %s", model_key, error)
                        continue

                    weight = weights.get(model_key, 1.0)
                    ensemble_probabilities += probabilities * weight
                    total_weight += weight

                    # Store individual results for the comparison panel.
                    predicted_class = np.argmax(probabilities)
                    confidence = probabilities[predicted_class]
                    model_predictions[model_key] = {
                        "prediction": self.label_names[predicted_class],
                        "confidence": float(confidence),
                        "probabilities": probabilities.tolist()
                    }

                    model_details.append(
                        f"**{self.model_configs[model_key]['name']}:** "
                        f"{self.label_names[predicted_class]} ({confidence:.2%})"
                    )

            if total_weight == 0:
                return "No models available for prediction.", {}, ""

            # Normalize by the weight actually used (some models may be missing).
            ensemble_probabilities = ensemble_probabilities / total_weight

            # Optional rule-based post-processing.
            rule_explanation = ""
            if use_rules:
                ensemble_probabilities, rule_explanation = rule_engine.apply_rules(
                    text, ensemble_probabilities, confidence_threshold=0.7
                )

            # Final prediction from the (possibly adjusted) ensemble vector.
            predicted_class = np.argmax(ensemble_probabilities)
            confidence = ensemble_probabilities[predicted_class]

            # Build the markdown-ish result text.
            # FIX: the previous version emitted literal "\n" (backslash + n)
            # sequences, so the results box showed "\n" instead of line breaks.
            if len(models_to_use) > 1:
                result_text = f"**🎯 Ensemble Prediction:** {self.label_names[predicted_class]}\n"
                result_text += f"**πŸ”₯ Ensemble Confidence:** {confidence:.2%}\n\n"

                result_text += "**πŸ€– Individual Model Results:**\n"
                for detail in model_details:
                    result_text += f"- {detail}\n"
                result_text += "\n"
            else:
                result_text = f"**🎯 Prediction:** {self.label_names[predicted_class]}\n"
                result_text += f"**πŸ”₯ Confidence:** {confidence:.2%}\n\n"

            # Show rule engine effects if applied.
            if use_rules and rule_explanation:
                result_text += f"**πŸ€– Rule Engine:** {rule_explanation}\n\n"

            result_text += "**πŸ“Š Final Probabilities:**\n"

            # Probability dictionary for the gr.Label component.
            prob_dict = {}
            for label, prob in zip(self.label_names, ensemble_probabilities):
                prob_dict[label] = float(prob)
                result_text += f"- {label}: {prob:.2%}\n"

            # Per-model breakdown, only meaningful with 2+ models.
            comparison_details = ""
            if len(model_predictions) > 1:
                comparison_details = "**πŸ” Model Comparison:**\n"
                for model_key, pred_data in model_predictions.items():
                    comparison_details += f"\n**{self.model_configs[model_key]['name']}:**\n"
                    for label, prob in zip(self.label_names, pred_data['probabilities']):
                        comparison_details += f"  - {label}: {prob:.2%}\n"

            return result_text, prob_dict, comparison_details

        except Exception as e:
            logger.error("Prediction error: %s", e)
            return f"Error during prediction: {str(e)}", {}, ""
328
+
329
# Bring the ensemble up once at import time.  If anything fails (download,
# memory, missing deps) we fall back to a disabled state so the UI can still
# render and report the problem instead of crashing the Space.
ensemble = None
available_models = []
try:
    ensemble = FinancialSentimentEnsemble()
    available_models = list(ensemble.models.keys())
    gpu_info = f"πŸš€ **Models loaded:** {len(available_models)} models on {ensemble.device}"
except Exception as exc:
    gpu_info = f"❌ **Error loading models:** {str(exc)}"
    ensemble = None
    available_models = []

def analyze_sentiment(text, model_selection, use_rules):
    """Gradio callback: run the selected model/ensemble and return UI outputs."""
    if ensemble is not None:
        return ensemble.predict_ensemble(text, model_selection, use_rules)
    # Models never came up -- surface a stable, user-facing message.
    return "Models not loaded. Please check the error above.", {}, ""
345
+
346
# Example inputs for the gr.Examples panel; each row is
# (text, model_selection value, use_rules flag) matching the inputs below.
examples = [
    ["Tesla stock is soaring after excellent Q3 earnings report! πŸš€", "smart_ensemble", True],
    ["The market is showing mixed signals today, uncertain direction.", "smart_ensemble", True],
    ["Major selloff expected as inflation concerns grow. Bearish outlook.", "all_models", True],
    ["Apple announces new iPhone with revolutionary features!", "distilbert", False],
    ["Economic indicators suggest potential recession ahead.", "bert_large", True],
    ["Crypto market rebounds strongly after recent dip.", "smart_ensemble", True]
]

# Create Gradio interface
with gr.Blocks(
    title="Financial Sentiment Analysis - Ensemble System",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1000px !important;
        margin: auto !important;
    }
    .header {
        text-align: center;
        padding: 20px;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    .model-info {
        background-color: #f8f9fa;
        padding: 15px;
        border-radius: 8px;
        margin: 10px 0;
    }
    """
) as demo:

    # Header banner; gpu_info is baked in at build time from the module-level
    # model-loading result (success summary or error message).
    gr.HTML(f"""
    <div class="header">
        <h1>πŸ“ˆ Financial Sentiment Analysis Ensemble</h1>
        <h3>Advanced AI-powered sentiment analysis for financial texts using an ensemble of 3 fine-tuned models</h3>
        <p>{gpu_info}</p>
    </div>
    """)

    with gr.Row():
        # Left column: input text, model picker, rule toggle, submit button.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="πŸ“ Enter Financial Text to Analyze",
                placeholder="Enter financial news, tweets, or market commentary...",
                lines=4
            )

            with gr.Row():
                # Dropdown values are the keys predict_ensemble() dispatches on.
                model_selection = gr.Dropdown(
                    choices=[
                        ("🧠 Smart Ensemble (Recommended)", "smart_ensemble"),
                        ("🎯 All Models Ensemble", "all_models"),
                        ("⚑ DistilBERT (Fast)", "distilbert"),
                        ("πŸ”₯ BERT-Large (Advanced)", "bert_large"),
                        ("πŸš€ Improved Model", "improved")
                    ],
                    value="smart_ensemble",
                    label="πŸ€– Model Selection"
                )

                use_rules = gr.Checkbox(
                    label="πŸ€– Rule-Based Enhancement",
                    value=True,
                    info="Apply keyword-based post-processing"
                )

            analyze_btn = gr.Button("πŸ” Analyze Sentiment", variant="primary", size="lg")

        # Right column: textual result plus class-probability label widget.
        with gr.Column(scale=2):
            result_output = gr.Textbox(
                label="πŸ“Š Analysis Results",
                lines=12,
                interactive=False
            )

            prob_output = gr.Label(
                label="πŸ“ˆ Probability Distribution",
                num_top_classes=3
            )

    with gr.Row():
        # Per-model probability breakdown (filled only for multi-model runs).
        comparison_output = gr.Textbox(
            label="πŸ” Model Comparison Details",
            lines=8,
            interactive=False,
            visible=True
        )

    # Event handlers: button click runs the analysis callback.
    analyze_btn.click(
        fn=analyze_sentiment,
        inputs=[text_input, model_selection, use_rules],
        outputs=[result_output, prob_output, comparison_output]
    )

    # Examples section; cache disabled so each click runs the live models.
    gr.Examples(
        examples=examples,
        inputs=[text_input, model_selection, use_rules],
        outputs=[result_output, prob_output, comparison_output],
        fn=analyze_sentiment,
        cache_examples=False,
        label="πŸ’‘ Try these examples:"
    )

    # Static model information footer.
    gr.HTML("""
    <div class="model-info">
        <h4>πŸ€– Ensemble System Information</h4>
        <ul>
            <li><strong>🧠 Smart Ensemble:</strong> DistilBERT + BERT-Large (Best balance of speed and accuracy)</li>
            <li><strong>🎯 All Models:</strong> DistilBERT + Improved + BERT-Large (Maximum consensus)</li>
            <li><strong>⚑ DistilBERT:</strong> Fast and efficient model optimized for real-time analysis</li>
            <li><strong>πŸ”₯ BERT-Large:</strong> Most advanced model with deep contextual understanding</li>
            <li><strong>πŸš€ Improved Model:</strong> Enhanced with advanced training techniques</li>
        </ul>
        <p><em>πŸ’‘ Tip: Smart Ensemble provides the best balance of accuracy and performance!</em></p>
        <p><em>πŸ€– Rule Engine: Applies keyword-based post-processing to improve accuracy on financial texts</em></p>
    </div>
    """)

if __name__ == "__main__":
    demo.launch()