# Source: utils/llm_analyzer.py (Todlong Space on Hugging Face)
# Last change: "Update utils/llm_analyzer.py" by cwpkd, commit cc84412 (verified)
# utils/llm_analyzer.py
"""
LLM-based analysis using Gemma model
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict
from config import GEMMA_MODEL, LLM_MAX_LENGTH, LLM_TEMPERATURE, LLM_TOP_P
class LLMAnalyzer:
    """Analyze and summarize financial news using a causal LLM (Gemma or Zephyr).

    Wraps a Hugging Face `transformers` causal LM and exposes small,
    prompt-based helpers for news summarization, sentiment explanation,
    and investment-insight generation.
    """

    def __init__(self):
        """Load the tokenizer and model named by ``config.GEMMA_MODEL``.

        Reads an optional Hugging Face access token from the ``HF_TOKEN``
        environment variable (required for gated models such as Gemma).
        Uses fp16 on GPU, fp32 on CPU, and lets ``device_map="auto"``
        place the weights.
        """
        import os  # local import: only needed here
        print("Loading Gemma model...")
        # Get token from environment (None is fine for ungated models)
        hf_token = os.environ.get("HF_TOKEN", None)
        self.tokenizer = AutoTokenizer.from_pretrained(
            GEMMA_MODEL,
            token=hf_token,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            GEMMA_MODEL,
            # fp16 only when a GPU is available; CPU fp16 inference is slow
            # or unsupported for many ops
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            token=hf_token,
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Gemma loaded on {self.device}!")

    def generate_response(self, prompt: str, max_length: int = LLM_MAX_LENGTH) -> str:
        """Generate a free-form response from the LLM.

        Args:
            prompt: Plain-text user prompt (chat markers are added here).
            max_length: Maximum number of NEW tokens to generate
                (passed to ``max_new_tokens``).

        Returns:
            The model's reply with the prompt and chat markers stripped.
        """
        # Wrap the prompt in the chat template matching the model family.
        if "gemma" in GEMMA_MODEL.lower():
            formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
        else:
            # Zephyr format
            formatted_prompt = f"<|user|>\n{prompt}</s>\n<|assistant|>\n"

        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=LLM_TEMPERATURE,
                top_p=LLM_TOP_P,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # FIX: decode only the newly generated tokens. The previous code
        # decoded the full sequence with skip_special_tokens=True and then
        # split on "<start_of_turn>model" / "<|assistant|>" — but those
        # markers are special tokens and are removed by the decode, so the
        # split rarely matched and the echoed prompt leaked into the output.
        prompt_token_count = inputs["input_ids"].shape[-1]
        response = self.tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        ).strip()

        # Defensive cleanup kept from the original, in case a tokenizer
        # leaves chat markers or the prompt text in the decoded output.
        if "<start_of_turn>model" in response:
            response = response.split("<start_of_turn>model")[-1].strip()
        elif "<|assistant|>" in response:
            response = response.split("<|assistant|>")[-1].strip()
        if prompt in response:
            response = response.replace(prompt, "").strip()
        return response

    def summarize_news(self, articles: List[Dict]) -> str:
        """Summarize up to five news articles into a short market overview.

        Args:
            articles: Article dicts with at least a ``'title'`` key and an
                optional ``'summary'`` key.

        Returns:
            A 2-3 sentence summary generated by the LLM.
        """
        # Build a numbered digest; cap at 5 articles and 200 chars of
        # summary each to keep the prompt within context limits.
        articles_text = ""
        for i, article in enumerate(articles[:5], 1):
            articles_text += f"{i}. {article['title']}\n"
            if 'summary' in article:
                articles_text += f"   {article['summary'][:200]}...\n\n"

        prompt = f"""Analyze these financial news headlines and provide a brief market summary (2-3 sentences):

{articles_text}

Summary:"""
        return self.generate_response(prompt, max_length=200)

    def analyze_sentiment_context(self, article: Dict, sentiment_data: Dict) -> str:
        """Explain why an article received a given sentiment classification.

        Args:
            article: Article dict with ``'title'`` and optional ``'summary'``.
            sentiment_data: Dict with ``'sentiment_label'`` and
                ``'confidence'`` (a fraction, formatted as a percentage).

        Returns:
            A 2-3 sentence explanation generated by the LLM.
        """
        sentiment_label = sentiment_data['sentiment_label']
        confidence = sentiment_data['confidence']

        prompt = f"""As a financial analyst, explain why this news headline has a {sentiment_label.lower()} sentiment (confidence: {confidence:.2%}):

Headline: {article['title']}
Summary: {article.get('summary', 'N/A')[:200]}

Provide a brief explanation (2-3 sentences):"""
        return self.generate_response(prompt, max_length=150)

    def generate_investment_insight(self, symbol: str, articles: List[Dict], sentiments: List[Dict]) -> str:
        """Generate investment insights for a ticker from news sentiment.

        Args:
            symbol: Stock ticker symbol.
            articles: Article dicts (only the first three titles are used).
            sentiments: Sentiment dicts with ``'combined_score'`` and
                ``'sentiment_label'`` keys.

        Returns:
            A 3-4 sentence insight generated by the LLM, or a fallback
            message when no sentiment data is available.
        """
        # FIX: the original divided by len(sentiments) unconditionally and
        # raised ZeroDivisionError when no sentiment data was supplied.
        if not sentiments:
            return f"No sentiment data available for {symbol}; unable to generate insights."

        avg_sentiment = sum(s['combined_score'] for s in sentiments) / len(sentiments)
        positive = sum(1 for s in sentiments if s['sentiment_label'] == 'Positive')
        negative = sum(1 for s in sentiments if s['sentiment_label'] == 'Negative')
        neutral = len(sentiments) - positive - negative

        headlines = "\n".join([f"- {a['title']}" for a in articles[:3]])

        prompt = f"""As a financial advisor, provide investment insights for {symbol} based on recent news sentiment:

Recent Headlines:
{headlines}

Sentiment Analysis:
- Positive: {positive}/{len(sentiments)}
- Negative: {negative}/{len(sentiments)}
- Neutral: {neutral}/{len(sentiments)}
- Average Score: {avg_sentiment:.2f}

Provide brief investment insights (3-4 sentences):"""
        return self.generate_response(prompt, max_length=250)