|
|
|
|
|
""" |
|
|
LLM-based analysis using Gemma model |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from typing import List, Dict |
|
|
from config import GEMMA_MODEL, LLM_MAX_LENGTH, LLM_TEMPERATURE, LLM_TOP_P |
|
|
|
|
|
|
|
|
class LLMAnalyzer:
    """Analyze and summarize financial news using a Gemma-style causal LM."""

    def __init__(self):
        """Load the tokenizer and model named by ``GEMMA_MODEL``.

        Uses fp16 on CUDA and fp32 on CPU. An optional ``HF_TOKEN``
        environment variable is forwarded to Hugging Face for gated
        models (official Gemma checkpoints require one).
        """
        import os  # local import: only needed here, keeps module import light

        print("Loading Gemma model...")

        hf_token = os.environ.get("HF_TOKEN")

        self.tokenizer = AutoTokenizer.from_pretrained(
            GEMMA_MODEL,
            token=hf_token,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            GEMMA_MODEL,
            # fp16 halves GPU memory; CPU inference is more stable in fp32.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            token=hf_token,
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Gemma loaded on {self.device}!")

    def generate_response(self, prompt: str, max_length: int = LLM_MAX_LENGTH) -> str:
        """
        Generate a completion for *prompt*.

        Args:
            prompt: Plain-text user prompt; chat-template markers are added here.
            max_length: Maximum number of NEW tokens to generate
                (passed as ``max_new_tokens``).

        Returns:
            The model's reply with chat markers and any echoed prompt removed.
        """
        # Wrap the prompt in the chat template the checkpoint was trained with.
        if "gemma" in GEMMA_MODEL.lower():
            formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
        else:
            # Zephyr/TinyLlama-style template for non-Gemma fallbacks.
            formatted_prompt = f"<|user|>\n{prompt}</s>\n<|assistant|>\n"

        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=LLM_TEMPERATURE,
                top_p=LLM_TOP_P,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # If the chat markers survived decoding (i.e. they were not registered
        # as special tokens), keep only the assistant's turn.
        if "<start_of_turn>model" in response:
            response = response.split("<start_of_turn>model")[-1].strip()
        elif "<|assistant|>" in response:
            response = response.split("<|assistant|>")[-1].strip()

        # Some models echo the prompt verbatim; strip it from the reply.
        if prompt in response:
            response = response.replace(prompt, "").strip()

        return response

    def summarize_news(self, articles: List[Dict]) -> str:
        """
        Summarize news articles.

        Args:
            articles: List of article dictionaries; each needs a 'title' and
                may carry a 'summary'.

        Returns:
            A short market summary generated by the LLM, or a fallback
            message when no articles are supplied.
        """
        # Guard: an empty list would otherwise produce a degenerate prompt
        # and waste a model call.
        if not articles:
            return "No articles available to summarize."

        # Cap at 5 articles and 200 chars of summary each to keep the
        # prompt within the model's context budget.
        articles_text = ""
        for i, article in enumerate(articles[:5], 1):
            articles_text += f"{i}. {article['title']}\n"
            if 'summary' in article:
                articles_text += f"   {article['summary'][:200]}...\n\n"

        prompt = f"""Analyze these financial news headlines and provide a brief market summary (2-3 sentences):

{articles_text}

Summary:"""

        return self.generate_response(prompt, max_length=200)

    def analyze_sentiment_context(self, article: Dict, sentiment_data: Dict) -> str:
        """
        Provide context for a sentiment analysis result.

        Args:
            article: Article dictionary with a 'title' (and optional 'summary').
            sentiment_data: Sentiment result with 'sentiment_label' and
                'confidence' keys.

        Returns:
            A brief LLM-generated explanation of the sentiment call.
        """
        sentiment_label = sentiment_data['sentiment_label']
        confidence = sentiment_data['confidence']

        prompt = f"""As a financial analyst, explain why this news headline has a {sentiment_label.lower()} sentiment (confidence: {confidence:.2%}):

Headline: {article['title']}
Summary: {article.get('summary', 'N/A')[:200]}

Provide a brief explanation (2-3 sentences):"""

        return self.generate_response(prompt, max_length=150)

    def generate_investment_insight(self, symbol: str, articles: List[Dict], sentiments: List[Dict]) -> str:
        """
        Generate investment insights based on news and sentiment.

        Args:
            symbol: Stock ticker symbol.
            articles: List of article dictionaries (titles are used).
            sentiments: List of sentiment results; each needs
                'combined_score' and 'sentiment_label'.

        Returns:
            LLM-generated investment insight text, or a fallback message
            when no sentiment data is supplied.
        """
        # Guard: avoids ZeroDivisionError on an empty sentiment list.
        if not sentiments:
            return f"No sentiment data available for {symbol}."

        avg_sentiment = sum(s['combined_score'] for s in sentiments) / len(sentiments)

        # Tally label counts; anything neither positive nor negative is neutral.
        positive = sum(1 for s in sentiments if s['sentiment_label'] == 'Positive')
        negative = sum(1 for s in sentiments if s['sentiment_label'] == 'Negative')
        neutral = len(sentiments) - positive - negative

        headlines = "\n".join([f"- {a['title']}" for a in articles[:3]])

        prompt = f"""As a financial advisor, provide investment insights for {symbol} based on recent news sentiment:

Recent Headlines:
{headlines}

Sentiment Analysis:
- Positive: {positive}/{len(sentiments)}
- Negative: {negative}/{len(sentiments)}
- Neutral: {neutral}/{len(sentiments)}
- Average Score: {avg_sentiment:.2f}

Provide brief investment insights (3-4 sentences):"""

        return self.generate_response(prompt, max_length=250)