# CFA_Ai_Agent / agent.py
# Navada25 — "Deploy CFA AI Agent with Finance-Llama-8B" (commit ce180e5)
"""
CFA AI Agent - LangChain Agent Setup
This module sets up the LangChain agent with Finance-Llama-8B model and financial tools.
"""
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.llms import HuggingFacePipeline
from langchain.agents import initialize_agent, AgentType
from langchain.memory import ConversationBufferMemory
from langchain.schema import SystemMessage
from langchain.prompts import MessagesPlaceholder
from typing import List, Any, Optional
# Import our custom tools
from tools.finance_tools import (
calculate_dcf,
calculate_sharpe_ratio,
compare_pe_ratios,
calculate_beta,
calculate_wacc,
financial_ratios_analysis
)
from tools.data_fetcher import (
get_stock_price,
get_historical_data,
get_company_info,
get_financial_statements,
get_market_indices,
compare_stocks
)
class CFAAgent:
    """
    CFA AI Agent that combines Finance-Llama-8B model with financial analysis tools.

    Wires together three components:
      * a Hugging Face causal-LM (Finance-Llama-8B by default) exposed to
        LangChain through ``HuggingFacePipeline``,
      * the financial calculation and market-data tools imported at module
        level, and
      * a conversational ReAct agent backed by buffer memory.
    """

    def __init__(self, model_name: str = "tarun7r/Finance-Llama-8B"):
        """
        Initialize the CFA Agent with model and tools.

        Args:
            model_name: Hugging Face model name for financial analysis

        Raises:
            RuntimeError: if the primary model, the OpenAI fallback and the
                DistilGPT2 emergency fallback all fail to load.
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.llm = None
        self.agent = None
        self.memory = None
        # Initialized up front so status helpers (health_check,
        # get_available_tools) never raise AttributeError even if a later
        # setup step fails part-way through.
        self.tools: List[Any] = []
        self._setup_model()
        self._setup_tools()
        self._setup_agent()

    def _setup_model(self):
        """Load the Finance-Llama-8B model and wrap it as a LangChain LLM.

        On any failure, falls back via ``_setup_fallback_model`` instead of
        raising, so the agent can still come up on a weaker model.
        """
        try:
            print(f"Loading model: {self.model_name}")
            # Check if CUDA is available
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            # Llama-family tokenizers often ship without a pad token; reuse
            # EOS so batching/padding in the pipeline works.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Load model with appropriate settings and memory optimization
            if device == "cuda":
                # NOTE(review): load_in_8bit is deprecated in recent
                # transformers releases in favour of BitsAndBytesConfig and
                # requires the bitsandbytes package — confirm the pinned
                # transformers version still accepts this kwarg.
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    load_in_8bit=True,  # Enable 8-bit quantization for memory efficiency
                    max_memory={0: "6GB"}  # Limit GPU memory usage
                )
            else:
                # For CPU, use aggressive memory optimization
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    torch_dtype=torch.float32,
                    device_map="cpu",
                    max_memory={"cpu": "8GB"}  # Limit CPU memory usage
                )
            # Create generation pipeline; low temperature keeps analytical
            # answers close to deterministic while do_sample avoids greedy
            # repetition together with repetition_penalty.
            pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=512,
                temperature=0.1,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
            # Wrap in LangChain
            self.llm = HuggingFacePipeline(pipeline=pipe)
            print("βœ… Model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading model: {str(e)}")
            # Fallback to a smaller model or OpenAI if Finance-Llama-8B fails
            self._setup_fallback_model()

    def _setup_fallback_model(self):
        """Setup a fallback model if Finance-Llama-8B fails to load.

        Tries OpenAI (when OPENAI_API_KEY is set), then DistilGPT2 as a last
        resort.

        Raises:
            RuntimeError: if every fallback also fails.
        """
        try:
            print("Setting up fallback model...")
            from langchain_community.llms import OpenAI
            # Check for OpenAI API key
            if os.getenv("OPENAI_API_KEY"):
                self.llm = OpenAI(
                    temperature=0.1,
                    model_name="gpt-3.5-turbo-instruct",
                    max_tokens=512
                )
                print("βœ… Using OpenAI GPT-3.5 as fallback")
            else:
                raise ValueError("No OpenAI API key found")
        except Exception as e:
            print(f"❌ Fallback model failed: {str(e)}")
            # Last resort: use a very small local model
            try:
                pipe = pipeline(
                    "text-generation",
                    model="distilgpt2",
                    max_new_tokens=256,
                    temperature=0.7
                )
                self.llm = HuggingFacePipeline(pipeline=pipe)
                print("βœ… Using DistilGPT2 as emergency fallback")
            except Exception as final_e:
                raise RuntimeError(f"All model loading attempts failed: {final_e}")

    def _setup_tools(self):
        """Setup all available financial analysis tools."""
        self.tools = [
            # Finance calculation tools
            calculate_dcf,
            calculate_sharpe_ratio,
            compare_pe_ratios,
            calculate_beta,
            calculate_wacc,
            financial_ratios_analysis,
            # Data fetching tools
            get_stock_price,
            get_historical_data,
            get_company_info,
            get_financial_statements,
            get_market_indices,
            compare_stocks
        ]
        print(f"βœ… Loaded {len(self.tools)} financial analysis tools")

    def _setup_agent(self):
        """Setup the LangChain agent with memory and tools.

        Raises:
            Exception: re-raised if agent construction fails.
        """
        try:
            # Setup conversation memory
            self.memory = ConversationBufferMemory(
                memory_key="chat_history",
                return_messages=True,
                output_key="output"
            )
            # BUG FIX: the original passed this memory to
            # AgentType.ZERO_SHOT_REACT_DESCRIPTION, whose prompt has no
            # "chat_history" input variable — LangChain rejects that
            # combination ("Got unexpected prompt input variables") and the
            # history would never reach the model anyway. The conversational
            # ReAct agent's prompt does include chat_history, so the buffer
            # memory is actually used.
            self.agent = initialize_agent(
                tools=self.tools,
                llm=self.llm,
                agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
                memory=self.memory,
                verbose=True,
                handle_parsing_errors=True,
                max_iterations=3,
                early_stopping_method="generate"
            )
            # Add system message for financial context
            system_message = """You are a CFA (Chartered Financial Analyst) AI assistant specialized in financial analysis, investment valuation, and portfolio management.
Your expertise includes:
- Financial statement analysis and ratio calculations
- Valuation models (DCF, comparable company analysis, etc.)
- Risk assessment and portfolio theory
- Market analysis and economic indicators
- Investment recommendations based on fundamental analysis
When answering questions:
1. Use the available financial tools to fetch real data when needed
2. Provide clear, professional explanations suitable for CFA-level analysis
3. Show your calculations and reasoning
4. Consider both quantitative and qualitative factors
5. Acknowledge limitations and assumptions in your analysis
You have access to real-time financial data and calculation tools. Use them to provide accurate, data-driven insights."""
            # Store system message for context.
            # NOTE(review): this message is stored but never injected into the
            # agent's prompt; query() adds its own CFA framing instead.
            self.system_message = system_message
            print("βœ… Agent initialized successfully")
        except Exception as e:
            print(f"❌ Error setting up agent: {str(e)}")
            raise

    def query(self, question: str) -> str:
        """
        Process a financial query using the CFA agent.

        Args:
            question: User's financial question or request

        Returns:
            Agent's response with analysis and recommendations, or an error
            message string if the agent run raised (never raises itself).
        """
        try:
            # Enhance the question with context
            enhanced_question = f"""As a CFA analyst, please help with the following:
{question}
Please provide a thorough analysis using available data and financial tools. Show your work and explain your reasoning."""
            # Get response from agent
            response = self.agent.run(enhanced_question)
            return response
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            print(error_msg)
            return error_msg

    def get_conversation_history(self) -> List[Any]:
        """Get the current conversation history (empty list if no memory)."""
        if self.memory:
            return self.memory.chat_memory.messages
        return []

    def clear_memory(self):
        """Clear the conversation memory (no-op when memory is unset)."""
        if self.memory:
            self.memory.clear()
            print("βœ… Conversation memory cleared")

    def get_available_tools(self) -> List[str]:
        """Get list of available tool names."""
        return [tool.name for tool in self.tools]

    def health_check(self) -> dict:
        """Perform a health check of the agent components.

        Returns:
            Mapping of component name to readiness flag, plus tool count and
            the compute device that would be used.
        """
        status = {
            "model_loaded": self.model is not None,
            "llm_ready": self.llm is not None,
            "agent_ready": self.agent is not None,
            "memory_ready": self.memory is not None,
            # getattr guard: stays valid even if _setup_tools never ran.
            "tools_count": len(getattr(self, "tools", [])),
            "device": "cuda" if torch.cuda.is_available() else "cpu"
        }
        return status
def create_cfa_agent(model_name: str = "tarun7r/Finance-Llama-8B") -> CFAAgent:
    """
    Factory function to create and return a CFA Agent instance.

    Args:
        model_name: Hugging Face model name for financial analysis

    Returns:
        Initialized CFAAgent instance

    Raises:
        Exception: re-raised from CFAAgent construction after logging.
    """
    try:
        instance = CFAAgent(model_name=model_name)
    except Exception as exc:
        print(f"❌ Failed to create CFA Agent: {str(exc)}")
        raise
    print("🎯 CFA Agent created successfully")
    return instance
# Example usage and testing
if __name__ == "__main__":
    print("πŸš€ Initializing CFA AI Agent...")
    try:
        # Bring up the agent, then report component readiness before
        # exercising it with sample queries.
        demo_agent = create_cfa_agent()
        report = demo_agent.health_check()
        print("πŸ“Š Health Check Results:")
        for key, value in report.items():
            marker = "βœ…" if value else "❌"
            print(f" {marker} {key}: {value}")
        # Smoke-test queries covering data fetch, tool use, and pure Q&A.
        test_queries = [
            "What is the current stock price of Apple (AAPL)?",
            "Calculate the PE ratio comparison between Apple and Microsoft",
            "Explain the CAPM model in simple terms"
        ]
        print("\nπŸ§ͺ Running test queries...")
        for idx, question in enumerate(test_queries, 1):
            print(f"\n--- Test Query {idx} ---")
            print(f"Q: {question}")
            try:
                answer = demo_agent.query(question)
                print(f"A: {answer}")
            except Exception as err:
                print(f"❌ Query failed: {str(err)}")
    except Exception as err:
        print(f"❌ CFA Agent initialization failed: {str(err)}")