codealchemist01 commited on
Commit
1c206bc
Β·
verified Β·
1 Parent(s): 333d5f2

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +234 -0
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Hugging Face Space App for Financial Sentiment Analysis Ensemble
4
+ """
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
+ import numpy as np
10
+ from datetime import datetime
11
+ import json
12
+
13
+ class FinancialSentimentEnsemble:
14
+ def __init__(self):
15
+ self.models = {}
16
+ self.tokenizers = {}
17
+ self.model_names = [
18
+ "codealchemist01/financial-sentiment-distilbert",
19
+ "codealchemist01/financial-sentiment-bert-large",
20
+ "codealchemist01/financial-sentiment-improved"
21
+ ]
22
+ self.labels = ["Bearish πŸ“‰", "Neutral ➑️", "Bullish πŸ“ˆ"]
23
+ self.load_models()
24
+
25
+ def load_models(self):
26
+ """Load all models and tokenizers"""
27
+ print("πŸš€ Loading Financial Sentiment Analysis Ensemble...")
28
+
29
+ for i, model_name in enumerate(self.model_names):
30
+ try:
31
+ print(f"πŸ“₯ Loading {model_name}...")
32
+ self.tokenizers[i] = AutoTokenizer.from_pretrained(model_name)
33
+ self.models[i] = AutoModelForSequenceClassification.from_pretrained(model_name)
34
+ self.models[i].eval()
35
+ print(f"βœ… {model_name} loaded successfully!")
36
+ except Exception as e:
37
+ print(f"❌ Error loading {model_name}: {e}")
38
+
39
+ print(f"πŸŽ‰ Ensemble ready with {len(self.models)} models!")
40
+
41
+ def predict_single_model(self, text, model_idx):
42
+ """Predict sentiment using a single model"""
43
+ if model_idx not in self.models:
44
+ return None
45
+
46
+ try:
47
+ inputs = self.tokenizers[model_idx](
48
+ text,
49
+ return_tensors="pt",
50
+ truncation=True,
51
+ padding=True,
52
+ max_length=512
53
+ )
54
+
55
+ with torch.no_grad():
56
+ outputs = self.models[model_idx](**inputs)
57
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
58
+
59
+ return probabilities[0].numpy()
60
+ except Exception as e:
61
+ print(f"Error in model {model_idx}: {e}")
62
+ return None
63
+
64
+ def predict_ensemble(self, text):
65
+ """Predict sentiment using ensemble of all models"""
66
+ if not text.strip():
67
+ return "Please enter some text to analyze.", {}, {}
68
+
69
+ individual_predictions = {}
70
+ all_probabilities = []
71
+
72
+ # Get predictions from each model
73
+ for i, model_name in enumerate(self.model_names):
74
+ probs = self.predict_single_model(text, i)
75
+ if probs is not None:
76
+ all_probabilities.append(probs)
77
+
78
+ # Individual model results
79
+ predicted_class = np.argmax(probs)
80
+ confidence = probs[predicted_class]
81
+
82
+ model_short_name = model_name.split("/")[-1].replace("financial-sentiment-", "").title()
83
+ individual_predictions[f"{model_short_name}"] = {
84
+ "Prediction": self.labels[predicted_class],
85
+ "Confidence": f"{confidence:.1%}"
86
+ }
87
+
88
+ if not all_probabilities:
89
+ return "Error: No models available for prediction.", {}, {}
90
+
91
+ # Ensemble prediction (average probabilities)
92
+ ensemble_probs = np.mean(all_probabilities, axis=0)
93
+ ensemble_prediction = np.argmax(ensemble_probs)
94
+ ensemble_confidence = ensemble_probs[ensemble_prediction]
95
+
96
+ # Create probability distribution for visualization
97
+ prob_dict = {}
98
+ for i, label in enumerate(self.labels):
99
+ prob_dict[label] = float(ensemble_probs[i])
100
+
101
+ # Result summary
102
+ result_text = f"""
103
+ ## 🎯 Ensemble Prediction: **{self.labels[ensemble_prediction]}**
104
+ **Confidence:** {ensemble_confidence:.1%}
105
+
106
+ ### πŸ“Š Probability Distribution:
107
+ - πŸ“‰ Bearish: {ensemble_probs[0]:.1%}
108
+ - ➑️ Neutral: {ensemble_probs[1]:.1%}
109
+ - πŸ“ˆ Bullish: {ensemble_probs[2]:.1%}
110
+
111
+ ### πŸ€– Individual Model Results:
112
+ """
113
+
114
+ for model_name, result in individual_predictions.items():
115
+ result_text += f"- **{model_name}**: {result['Prediction']} ({result['Confidence']})\n"
116
+
117
+ return result_text, prob_dict, individual_predictions
118
+
119
+ # Initialize the ensemble
120
+ ensemble = FinancialSentimentEnsemble()
121
+
122
+ def analyze_sentiment(text):
123
+ """Main function for Gradio interface"""
124
+ return ensemble.predict_ensemble(text)
125
+
126
+ # Example texts for demonstration
127
+ examples = [
128
+ "The stock market is showing strong bullish momentum with record highs across major indices.",
129
+ "Company earnings fell short of expectations, leading to a significant drop in share price.",
130
+ "The Federal Reserve maintained interest rates, keeping market conditions stable.",
131
+ "Tesla's innovative battery technology could revolutionize the automotive industry.",
132
+ "Rising inflation concerns are creating uncertainty in the financial markets.",
133
+ "The merger announcement sent both companies' stock prices soaring.",
134
+ "Quarterly results were mixed, with some sectors outperforming while others lagged."
135
+ ]
136
+
137
+ # Create Gradio interface
138
+ with gr.Blocks(
139
+ theme=gr.themes.Soft(),
140
+ title="Financial Sentiment Analysis Ensemble",
141
+ css="""
142
+ .gradio-container {
143
+ max-width: 1200px !important;
144
+ }
145
+ .main-header {
146
+ text-align: center;
147
+ margin-bottom: 2rem;
148
+ }
149
+ .model-info {
150
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
151
+ color: white;
152
+ padding: 1rem;
153
+ border-radius: 10px;
154
+ margin: 1rem 0;
155
+ }
156
+ """
157
+ ) as demo:
158
+
159
+ gr.HTML("""
160
+ <div class="main-header">
161
+ <h1>πŸš€ Financial Sentiment Analysis Ensemble</h1>
162
+ <p>Advanced AI-powered sentiment analysis for financial texts using an ensemble of 3 fine-tuned models</p>
163
+ </div>
164
+ """)
165
+
166
+ with gr.Row():
167
+ with gr.Column(scale=2):
168
+ text_input = gr.Textbox(
169
+ label="πŸ“ Enter Financial Text",
170
+ placeholder="Type or paste financial news, social media posts, or market commentary here...",
171
+ lines=4,
172
+ max_lines=10
173
+ )
174
+
175
+ analyze_btn = gr.Button("πŸ” Analyze Sentiment", variant="primary", size="lg")
176
+
177
+ gr.Examples(
178
+ examples=examples,
179
+ inputs=text_input,
180
+ label="πŸ’‘ Try these examples:"
181
+ )
182
+
183
+ with gr.Column(scale=3):
184
+ result_output = gr.Markdown(label="πŸ“Š Analysis Results")
185
+
186
+ with gr.Row():
187
+ prob_plot = gr.BarPlot(
188
+ x="Sentiment",
189
+ y="Probability",
190
+ title="Ensemble Probability Distribution",
191
+ x_title="Sentiment Categories",
192
+ y_title="Probability",
193
+ width=400,
194
+ height=300
195
+ )
196
+
197
+ individual_results = gr.JSON(
198
+ label="πŸ€– Individual Model Predictions",
199
+ visible=True
200
+ )
201
+
202
+ # Model Information
203
+ gr.HTML("""
204
+ <div class="model-info">
205
+ <h3>🧠 Ensemble Models:</h3>
206
+ <ul>
207
+ <li><strong>DistilBERT Model:</strong> Fast and efficient, optimized for real-time analysis</li>
208
+ <li><strong>BERT-Large Model:</strong> High accuracy with deep contextual understanding</li>
209
+ <li><strong>Improved Model:</strong> Enhanced with advanced training techniques</li>
210
+ </ul>
211
+ <p><strong>Ensemble Accuracy:</strong> 79.7% | <strong>Categories:</strong> Bearish πŸ“‰, Neutral ➑️, Bullish πŸ“ˆ</p>
212
+ </div>
213
+ """)
214
+
215
+ # Event handlers
216
+ analyze_btn.click(
217
+ fn=analyze_sentiment,
218
+ inputs=text_input,
219
+ outputs=[result_output, prob_plot, individual_results]
220
+ )
221
+
222
+ text_input.submit(
223
+ fn=analyze_sentiment,
224
+ inputs=text_input,
225
+ outputs=[result_output, prob_plot, individual_results]
226
+ )
227
+
228
+ # Launch the app
229
+ if __name__ == "__main__":
230
+ demo.launch(
231
+ server_name="0.0.0.0",
232
+ server_port=7860,
233
+ share=False
234
+ )