MahekTrivedi commited on
Commit
aacc09d
·
verified ·
1 Parent(s): 65f8501

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +410 -12
src/streamlit_app.py CHANGED
@@ -1,9 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
  import torch
 
 
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
6
- from peft import PeftModel
7
  from sklearn.linear_model import LinearRegression
8
  from sklearn.model_selection import train_test_split
9
  from sklearn.metrics import mean_squared_error, r2_score
@@ -16,23 +403,34 @@ st.set_page_config(layout="wide", page_title="FinGPT Investment Predictor")
16
 
17
  # Use st.cache_resource to load heavy models only once
18
  @st.cache_resource
19
- def load_sentiment_model(sentiment_model_name="FinGPT/fingpt-sentiment_llama2-13b_lora",
20
- base_tokenizer_name="meta-llama/Llama-2-13b-chat-hf"): # Added base_tokenizer_name
21
  """
22
- Loads the pre-trained sentiment analysis model and tokenizer.
23
  Uses st.cache_resource to prevent reloading on every Streamlit rerun.
24
  """
25
- st.write(f"Loading sentiment tokenizer from base model: {base_tokenizer_name}...")
26
- tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name) # Load tokenizer from base Llama model
 
 
27
 
28
- st.write(f"Loading sentiment model: {sentiment_model_name}...")
29
- # Load the sentiment model (which is the LoRA adapter)
30
- model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
31
- model.eval() # Set model to evaluation mode
 
 
 
 
 
 
 
 
32
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
- model.to(device)
34
  st.write(f"Sentiment model loaded. Using device: {device}")
35
- return tokenizer, model, device
 
36
 
37
  tokenizer, sentiment_model, device = load_sentiment_model()
38
 
 
1
+ # import streamlit as st
2
+ # import pandas as pd
3
+ # import numpy as np
4
+ # import torch
5
+ # from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
6
+ # from peft import PeftModel
7
+ # from sklearn.linear_model import LinearRegression
8
+ # from sklearn.model_selection import train_test_split
9
+ # from sklearn.metrics import mean_squared_error, r2_score
10
+ # import matplotlib.pyplot as plt
11
+
12
+ # # Set page configuration
13
+ # st.set_page_config(layout="wide", page_title="FinGPT Investment Predictor")
14
+
15
+ # # --- Part 1: Financial Sentiment Analysis Model Loading and Function ---
16
+
17
+ # # Use st.cache_resource to load heavy models only once
18
+ # @st.cache_resource
19
+ # def load_sentiment_model(sentiment_model_name="FinGPT/fingpt-sentiment_llama2-13b_lora",
20
+ # base_tokenizer_name="meta-llama/Llama-2-13b-chat-hf"): # Added base_tokenizer_name
21
+ # """
22
+ # Loads the pre-trained sentiment analysis model and tokenizer.
23
+ # Uses st.cache_resource to prevent reloading on every Streamlit rerun.
24
+ # """
25
+ # st.write(f"Loading sentiment tokenizer from base model: {base_tokenizer_name}...")
26
+ # tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name) # Load tokenizer from base Llama model
27
+
28
+ # st.write(f"Loading sentiment model: {sentiment_model_name}...")
29
+ # # Load the sentiment model (which is the LoRA adapter)
30
+ # model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
31
+ # model.eval() # Set model to evaluation mode
32
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ # model.to(device)
34
+ # st.write(f"Sentiment model loaded. Using device: {device}")
35
+ # return tokenizer, model, device
36
+
37
+ # tokenizer, sentiment_model, device = load_sentiment_model()
38
+
39
+ # def get_sentiment_score_and_label(text):
40
+ # """
41
+ # Analyzes the sentiment of the given text using the loaded model.
42
+ # Returns a numerical score (-1 to 1) and a categorical label (negative/neutral/positive).
43
+ # """
44
+ # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
45
+ # with torch.no_grad():
46
+ # outputs = sentiment_model(**inputs)
47
+ # logits = outputs.logits
48
+
49
+ # probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
50
+
51
+ # # Assuming the FinGPT sentiment model outputs logits in the order: [negative, neutral, positive].
52
+ # neg_score = probabilities[0]
53
+ # neu_score = probabilities[1]
54
+ # pos_score = probabilities[2]
55
+
56
+ # # A simple weighted average to get a single sentiment score between -1 and 1
57
+ # sentiment_score = (pos_score * 1) + (neg_score * -1) + (neu_score * 0)
58
+
59
+ # labels = ["negative", "neutral", "positive"]
60
+ # predicted_class_id = logits.argmax().item()
61
+ # predicted_label = labels[predicted_class_id]
62
+
63
+ # return sentiment_score, predicted_label
64
+
65
+ # # --- Part 2: Simulate Financial Data and Train Prediction Model ---
66
+
67
+ # # Use st.cache_data to run data generation and model training only once
68
+ # @st.cache_data
69
+ # def prepare_data_and_train_model():
70
+ # """
71
+ # Simulates historical financial data, calculates sentiment,
72
+ # and trains a simple linear regression model.
73
+ # """
74
+ # st.write("Simulating financial data and training prediction model...")
75
+ # # Let's create some dummy historical data for a stock and news headlines
76
+ # dates = pd.to_datetime(pd.date_range(start='2024-01-01', periods=60, freq='D')) # Increased period for better visualization
77
+ # np.random.seed(42) # for reproducibility
78
+
79
+ # # Simulate stock prices with some trend and noise
80
+ # base_price = 100
81
+ # prices = [base_price + np.random.uniform(-2, 2)]
82
+ # for _ in range(1, len(dates)):
83
+ # change = np.random.uniform(-1, 1) + (np.random.uniform(-0.5, 0.5) * np.random.choice([-1, 1], p=[0.2, 0.8]))
84
+ # prices.append(prices[-1] + change)
85
+ # prices = np.array(prices) + np.cumsum(np.random.uniform(-0.1, 0.3, len(dates)))
86
+
87
+ # # Simulate financial news headlines for each day
88
+ # dummy_news = [
89
+ # "Tech company reports strong Q4 earnings, stock up.",
90
+ # "Market shows signs of recovery after recent dip.",
91
+ # "Government announces new regulations affecting energy sector.",
92
+ # "Company X faces legal challenges, shares fall.",
93
+ # "Positive economic indicators boost investor confidence.",
94
+ # "Global supply chain issues continue to impact manufacturing.",
95
+ # "Innovation in AI drives new growth opportunities for company Z.",
96
+ # "Analyst downgrades stock, citing valuation concerns.",
97
+ # "Central bank holds interest rates steady, as expected.",
98
+ # "Surprise acquisition announced, boosting stock of target company.",
99
+ # "Quarterly sales below expectations, cautious outlook given.",
100
+ # "New product launch receives mixed reviews.",
101
+ # "Commodity prices stabilizing after volatile period.",
102
+ # "Major competitor announces expansion plans.",
103
+ # "Strong consumer spending data released.",
104
+ # "Company Y CEO resigns amidst controversy.",
105
+ # "Healthcare sector sees increased M&A activity.",
106
+ # "Inflation concerns persist, impacting consumer sentiment.",
107
+ # "Renewable energy stocks gain traction.",
108
+ # "Geopolitical tensions rise, market volatility increases.",
109
+ # "Positive clinical trial results for biotech firm.",
110
+ # "Earnings report beats estimates, stock rallies.",
111
+ # "Product recall announced, shares decline sharply.",
112
+ # "New trade agreement expected to benefit exporters.",
113
+ # "Tech giant invests heavily in R&D, future growth anticipated.",
114
+ # "Retail sales unexpectedly weak last month.",
115
+ # "Financial institution expands services, positive outlook.",
116
+ # "Mining company faces environmental penalties.",
117
+ # "Market sentiment remains cautiously optimistic.",
118
+ # "Dividend increase announced, attracting income investors.",
119
+ # "Breakthrough in medical research boosts pharma stock.",
120
+ # "New energy policy to impact utility companies.",
121
+ # "Cybersecurity firm reports major data breach.",
122
+ # "E-commerce sales exceed forecasts during holiday season.",
123
+ # "Automotive industry faces chip shortage challenges.",
124
+ # "Biotech startup secures significant funding.",
125
+ # "Real estate market shows signs of cooling.",
126
+ # "Investment bank upgrades rating for tech stock.",
127
+ # "Consumer confidence index reaches new high.",
128
+ # "Airline industry recovers strongly post-pandemic.",
129
+ # "Retailer announces store closures, stock drops.",
130
+ # "Software company acquires competitor, market reacts positively.",
131
+ # "Global trade tensions ease, positive for exporters.",
132
+ # "New environmental regulations impact manufacturing costs.",
133
+ # "Tourism sector sees strong rebound, hotel stocks rise.",
134
+ # "Bank reports solid profits despite economic headwinds.",
135
+ # "Construction firm wins major government contract.",
136
+ # "Food prices continue to rise, affecting grocery chains.",
137
+ # "Luxury brand sales surge in emerging markets.",
138
+ # "Telecommunications giant invests in 5G infrastructure."
139
+ # ]
140
+
141
+ # # Ensure we have enough news for all dates
142
+ # if len(dummy_news) < len(dates):
143
+ # dummy_news_extended = (dummy_news * (len(dates) // len(dummy_news) + 1))[:len(dates)]
144
+ # else:
145
+ # dummy_news_extended = dummy_news[:len(dates)]
146
+
147
+ # # Calculate sentiment scores for each day's news
148
+ # sentiment_scores = [get_sentiment_score_and_label(news)[0] for news in dummy_news_extended]
149
+
150
+ # # Create a DataFrame
151
+ # data = pd.DataFrame({
152
+ # 'Date': dates,
153
+ # 'Price': prices,
154
+ # 'Sentiment': sentiment_scores,
155
+ # 'News': dummy_news_extended
156
+ # })
157
+ # data.set_index('Date', inplace=True)
158
+
159
+ # # Add lagged price and sentiment as features
160
+ # data['Previous_Day_Price'] = data['Price'].shift(1)
161
+ # data['Previous_Day_Sentiment'] = data['Sentiment'].shift(1)
162
+
163
+ # # Drop the first row which will have NaN due to shifting
164
+ # data.dropna(inplace=True)
165
+
166
+ # # We want to predict 'Price' based on 'Previous_Day_Price' and 'Previous_Day_Sentiment'
167
+ # X = data[['Previous_Day_Price', 'Previous_Day_Sentiment']]
168
+ # y = data['Price']
169
+
170
+ # # Split data into training and testing sets (chronologically)
171
+ # test_size_ratio = 0.2
172
+ # split_index = int(len(data) * (1 - test_size_ratio))
173
+ # X_train, X_test = X.iloc[:split_index], X.iloc[split_index:]
174
+ # y_train, y_test = y.iloc[:split_index], y.iloc[split_index:]
175
+
176
+ # # Initialize and train the Linear Regression model
177
+ # model_prediction = LinearRegression()
178
+ # model_prediction.fit(X_train, y_train)
179
+
180
+ # # Make predictions on the test set
181
+ # y_pred = model_prediction.predict(X_test)
182
+
183
+ # # Evaluate the model
184
+ # mse = mean_squared_error(y_test, y_pred)
185
+ # r2 = r2_score(y_test, y_pred)
186
+
187
+ # return data, X_train, X_test, y_train, y_test, y_pred, model_prediction, mse, r2
188
+
189
+ # # Run the data preparation and model training
190
+ # data, X_train, X_test, y_train, y_test, y_pred, model_prediction, mse, r2 = prepare_data_and_train_model()
191
+
192
+ # # --- Part 3: LLM Forecaster Model Loading and Function (Conceptual) ---
193
+
194
+ # # Load the base Llama-2-7b-chat model (required by the LoRA adapter)
195
+ # @st.cache_resource
196
+ # def load_base_llm_model(base_model_name="meta-llama/Llama-2-7b-chat-hf"):
197
+ # """
198
+ # Loads the base Large Language Model required for the FinGPT forecaster.
199
+ # """
200
+ # st.write(f"Loading base LLM model: {base_model_name}...")
201
+ # base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
202
+ # base_model = AutoModelForCausalLM.from_pretrained(
203
+ # base_model_name,
204
+ # torch_dtype=torch.float16, # Use float16 for memory efficiency
205
+ # device_map="auto" # Automatically maps model to available devices (GPU if available)
206
+ # )
207
+ # return base_tokenizer, base_model
208
+
209
+ # base_tokenizer, base_llm_model = load_base_llm_model()
210
+
211
+ # @st.cache_resource
212
+ # def load_forecaster_model(forecaster_model_name="FinGPT/fingpt-forecaster_dow30_llama2-7b_lora"):
213
+ # """
214
+ # Loads the FinGPT forecaster LoRA adapter and merges it with the base LLM.
215
+ # """
216
+ # st.write(f"Loading FinGPT forecaster model: {forecaster_model_name}...")
217
+ # # Load the LoRA adapter
218
+ # model = PeftModel.from_pretrained(base_llm_model, forecaster_model_name)
219
+ # model = model.eval()
220
+ # st.write(f"FinGPT forecaster model loaded.")
221
+ # return model
222
+
223
+ # forecaster_llm_model = load_forecaster_model()
224
+
225
+ # def get_llm_forecast(ticker, current_date, news_summary, current_price):
226
+ # """
227
+ # Generates a text-based forecast using the FinGPT forecaster LLM.
228
+ # This is a conceptual demonstration of LLM forecasting.
229
+ # """
230
+ # # Construct a prompt for the LLM forecaster
231
+ # # This prompt structure is simplified; real FinGPT forecasters might expect more complex inputs
232
+ # # based on their training data (e.g., historical data, financial statements).
233
+ # prompt = f"""
234
+ # You are a financial expert.
235
+ # Analyze the following information and provide a concise prediction for the stock price movement of {ticker} for the next week.
236
+
237
+ # Current Date: {current_date.strftime('%Y-%m-%d')}
238
+ # Current Price of {ticker}: ${current_price:.2f}
239
+ # Recent News: "{news_summary}"
240
+
241
+ # Based on this information, what is your prediction for {ticker}'s stock price movement next week? Provide a brief analysis.
242
+ # """
243
+
244
+ # # For Llama-2-chat models, typically prompts are wrapped with specific tokens
245
+ # # This is a common format for instruction-tuned Llama models.
246
+ # chat_template = f"<s>[INST] {prompt} [/INST]"
247
+
248
+ # inputs = base_tokenizer(chat_template, return_tensors="pt").to(device)
249
+
250
+ # # Generate response from the LLM
251
+ # with torch.no_grad():
252
+ # outputs = forecaster_llm_model.generate(
253
+ # **inputs,
254
+ # max_new_tokens=200, # Limit the length of the generated response
255
+ # num_return_sequences=1,
256
+ # do_sample=True,
257
+ # top_k=50,
258
+ # top_p=0.95,
259
+ # temperature=0.7,
260
+ # )
261
+
262
+ # # Decode the generated text, removing the prompt part
263
+ # response_text = base_tokenizer.decode(outputs[0], skip_special_tokens=True)
264
+ # # The response will include the prompt. We need to find the actual generated text.
265
+ # # This parsing might need to be more robust depending on the model's exact output format.
266
+ # generated_forecast = response_text.split("[/INST]")[-1].strip()
267
+
268
+ # return generated_forecast
269
+
270
+
271
+ # # --- Streamlit UI Layout ---
272
+
273
+ # st.title("FinGPT-Powered Investment Predictor 📈")
274
+
275
+ # st.markdown("""
276
+ # This application demonstrates how financial news sentiment can be integrated with historical price data
277
+ # to make simple investment predictions. It uses a pre-trained FinBERT-like model for sentiment analysis
278
+ # and a Linear Regression model for price prediction.
279
+
280
+ # **Disclaimer:** This is a simplified demonstration for educational purposes only.
281
+ # It should **NOT** be used for actual investment decisions. Stock market prediction is highly complex
282
+ # and involves many more factors and sophisticated models.
283
+ # """)
284
+
285
+ # # Create tabs for better organization
286
+ # tab1, tab2, tab3, tab4 = st.tabs(["Sentiment Analyzer", "Historical Data & Model Performance", "Predict Tomorrow's Price (Linear Regression)", "LLM Forecaster (Conceptual)"])
287
+
288
+ # with tab1:
289
+ # st.header("Financial News Sentiment Analyzer")
290
+ # st.write("Enter a financial news headline or text to get its sentiment.")
291
+
292
+ # news_input = st.text_area("News Text:", height=150, placeholder="e.g., 'Apple's stock surged after reporting record-breaking earnings.'")
293
+
294
+ # if st.button("Analyze Sentiment"):
295
+ # if news_input:
296
+ # sentiment_score, sentiment_label = get_sentiment_score_and_label(news_input)
297
+ # st.markdown(f"**Sentiment Score:** `{sentiment_score:.3f}` (closer to 1 is positive, -1 is negative)")
298
+ # st.markdown(f"**Sentiment Label:** `{sentiment_label.upper()}`")
299
+ # else:
300
+ # st.warning("Please enter some text to analyze sentiment.")
301
+
302
+ # with tab2:
303
+ # st.header("Simulated Historical Data and Model Performance")
304
+ # st.write("Here's a look at the simulated historical data used and how the prediction model performed on the test set.")
305
+
306
+ # st.subheader("Sample Historical Data")
307
+ # st.dataframe(data.head())
308
+
309
+ # st.subheader("Prediction Model Performance")
310
+ # st.write(f"**Mean Squared Error (MSE):** `{mse:.2f}`")
311
+ # st.write(f"**R-squared (R2):** `{r2:.2f}`")
312
+ # st.info("A lower MSE indicates better prediction accuracy, and an R2 closer to 1 indicates that the model explains more of the variance in the target variable.")
313
+
314
+ # st.subheader("Actual vs. Predicted Prices (Test Set)")
315
+ # fig, ax = plt.subplots(figsize=(12, 6))
316
+ # ax.plot(y_test.index, y_test, label='Actual Price', marker='o', linestyle='-', markersize=4)
317
+ # ax.plot(y_test.index, y_pred, label='Predicted Price', marker='x', linestyle='--', markersize=4)
318
+ # ax.set_title('Stock Price Prediction with Sentiment (Simulated Data)')
319
+ # ax.set_xlabel('Date')
320
+ # ax.set_ylabel('Price')
321
+ # ax.legend()
322
+ # ax.grid(True)
323
+ # st.pyplot(fig)
324
+
325
+ # with tab3:
326
+ # st.header("Predict Tomorrow's Price (Linear Regression)")
327
+ # st.write("Enter today's closing price and relevant news to get a conceptual prediction for tomorrow using a Linear Regression model.")
328
+
329
+ # last_known_price_simulated = data['Price'].iloc[-1]
330
+
331
+ # col1, col2 = st.columns(2)
332
+ # with col1:
333
+ # today_closing_price = st.number_input(
334
+ # "Today's Closing Price:",
335
+ # min_value=0.0,
336
+ # value=float(f"{last_known_price_simulated:.2f}"),
337
+ # step=0.1,
338
+ # help="Based on the last simulated price from the historical data."
339
+ # )
340
+ # with col2:
341
+ # today_news_headline = st.text_area(
342
+ # "Today's Financial News Headline:",
343
+ # value="Market shows strong upward momentum, positive outlook for tech sector.",
344
+ # height=100
345
+ # )
346
+
347
+ # if st.button("Predict Price (Linear Regression)"):
348
+ # if today_closing_price is not None and today_news_headline:
349
+ # latest_sentiment_score, latest_sentiment_label = get_sentiment_score_and_label(today_news_headline)
350
+ # st.write(f"**Analyzed Sentiment for Today's News:** `{latest_sentiment_label.upper()}` (Score: `{latest_sentiment_score:.3f}`)")
351
+
352
+ # # Prepare data for prediction
353
+ # new_data_for_prediction = pd.DataFrame({
354
+ # 'Previous_Day_Price': [today_closing_price],
355
+ # 'Previous_Day_Sentiment': [latest_sentiment_score]
356
+ # })
357
+
358
+ # # Make prediction
359
+ # tomorrow_predicted_price = model_prediction.predict(new_data_for_prediction)[0]
360
+
361
+ # st.success(f"**Conceptual Prediction for Tomorrow's Price:** `${tomorrow_predicted_price:.2f}`")
362
+ # st.info("Remember, this is a conceptual prediction based on a simplified model and simulated data.")
363
+ # else:
364
+ # st.warning("Please enter both today's closing price and news headline to make a prediction.")
365
+
366
+ # with tab4:
367
+ # st.header("LLM Forecaster (Conceptual)")
368
+ # st.write("This section demonstrates how a FinGPT forecaster LLM *could* generate a text-based forecast.")
369
+ # st.warning("Note: This model requires significant memory (GPU recommended) and its output is text-based analysis, not a precise numerical prediction like the Linear Regression model.")
370
+
371
+ # llm_ticker = st.text_input("Stock Ticker (e.g., AAPL, MSFT):", value="AAPL")
372
+ # llm_current_price = st.number_input("Current Price:", min_value=0.0, value=175.0, step=0.1)
373
+ # llm_news_summary = st.text_area("Recent News Summary:", value="Apple announced strong Q4 earnings, beating analyst expectations and showing robust iPhone sales in emerging markets.", height=100)
374
+ # llm_current_date = st.date_input("Current Date:", value=pd.to_datetime('2024-07-15'))
375
+
376
+ # if st.button("Get LLM Forecast"):
377
+ # if llm_ticker and llm_current_price is not None and llm_news_summary and llm_current_date:
378
+ # with st.spinner("Generating LLM forecast... This may take a moment."):
379
+ # forecast_text = get_llm_forecast(llm_ticker, llm_current_date, llm_news_summary, llm_current_price)
380
+ # st.subheader("LLM's Forecast and Analysis:")
381
+ # st.write(forecast_text)
382
+ # else:
383
+ # st.warning("Please fill in all fields to get an LLM forecast.")
384
+
385
+
386
  import streamlit as st
387
  import pandas as pd
388
  import numpy as np
389
  import torch
390
+ # from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
391
+ # from peft import PeftModel
392
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
393
+ from peft import PeftModel, LoraConfig, get_peft_model # Ensure PeftModel is imported
394
  from sklearn.linear_model import LinearRegression
395
  from sklearn.model_selection import train_test_split
396
  from sklearn.metrics import mean_squared_error, r2_score
 
403
 
404
  # Use st.cache_resource to load heavy models only once
405
  @st.cache_resource
406
+ def load_sentiment_model(sentiment_lora_name="FinGPT/fingpt-sentiment_llama2-13b_lora",
407
+ base_model_name="meta-llama/Llama-2-13b-chat-hf"):
408
  """
409
+ Loads the pre-trained sentiment analysis model (base + LoRA) and tokenizer.
410
  Uses st.cache_resource to prevent reloading on every Streamlit rerun.
411
  """
412
+ st.write(f"Loading base tokenizer from: {base_model_name}...")
413
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
414
+ if tokenizer.pad_token is None:
415
+ tokenizer.pad_token = tokenizer.eos_token
416
 
417
+ st.write(f"Loading base model for sentiment: {base_model_name}...")
418
+ # Load the base Llama 2 model as AutoModelForCausalLM
419
+ base_model = AutoModelForCausalLM.from_pretrained(
420
+ base_model_name,
421
+ torch_dtype=torch.float16, # Or bfloat16 if your GPU supports it
422
+ device_map="auto"
423
+ )
424
+
425
+ st.write(f"Loading sentiment LoRA adapter: {sentiment_lora_name}...")
426
+ # Load the LoRA adapter on top of the base model
427
+ sentiment_model = PeftModel.from_pretrained(base_model, sentiment_lora_name)
428
+ sentiment_model.eval()
429
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
430
+ sentiment_model.to(device)
431
  st.write(f"Sentiment model loaded. Using device: {device}")
432
+
433
+ return tokenizer, sentiment_model, device
434
 
435
  tokenizer, sentiment_model, device = load_sentiment_model()
436