# Provenance: Hugging Face upload by vashu2425 — "Update llm_inference.py" (commit ba7a5fc, verified)
"""
LLM Inference Module
This module handles all interactions with the Groq API via LangChain,
allowing the application to generate EDA insights and feature engineering
recommendations from dataset analysis.
"""
# Standard library
import logging
import os
import time
from typing import Dict, Any, List, Optional

# Third-party
from dotenv import load_dotenv

# LangChain imports (get_openai_callback was previously imported twice; deduplicated)
from langchain_groq import ChatGroq
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_community.callbacks.manager import get_openai_callback
from langchain_core.runnables import RunnableSequence

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Load environment variables; the Groq key is mandatory for this module.
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found in environment variables. Please add it to your .env file.")

# Create one shared LLM client for the default model (reused by LLMInference).
try:
    llm = ChatGroq(model_name="llama3-8b-8192", groq_api_key=GROQ_API_KEY)
    logger.info("Successfully initialized Groq client")
except Exception as e:
    logger.error(f"Failed to initialize Groq client: {str(e)}")
    raise
class LLMInference:
    """Class for interacting with LLM via Groq API using LangChain.

    Wraps a ChatGroq client behind prompt-template chains for EDA insights,
    feature-engineering recommendations, data-quality assessment, and
    dataset Q&A (with or without conversation memory).
    """

    def __init__(self, model_id: str = "llama3-8b-8192"):
        """Initialize the LLM inference class with Groq model.

        Args:
            model_id: Groq model identifier. The shared module-level client is
                reused for the default model; any other id gets a dedicated
                ChatGroq client (previously a non-default id was silently
                ignored and the default model was used regardless).
        """
        self.model_id = model_id
        if model_id == "llama3-8b-8192":
            # Reuse the client created at import time.
            self.llm = llm
        else:
            # Honor the caller-supplied model instead of ignoring it.
            self.llm = ChatGroq(model_name=model_id, groq_api_key=GROQ_API_KEY)
        # Initialize prompt templates and chains
        self._init_prompt_templates()
        self._init_chains()
        logger.info(f"LLMInference initialized with model: {model_id}")

    def _init_prompt_templates(self):
        """Initialize all prompt templates (EDA, feature engineering, data quality, QA)."""
        # EDA insights prompt template
        self.eda_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data scientist tasked with performing Exploratory Data Analysis (EDA) on a dataset.
Based on the following dataset information, provide comprehensive EDA insights:
Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}
- Missing values:
{missing_info}
- Basic statistics:
{basic_stats}
- Top correlations:
{correlations}
- Sample data:
{sample_data}
Please provide a detailed EDA analysis that includes:
1. Summary of the dataset (what it appears to be about, key features, etc.)
2. Distribution analysis of key variables
3. Relationship analysis between variables
4. Identification of patterns, outliers, or anomalies
5. Recommended visualizations that would be insightful
6. Initial hypotheses based on the data
Your analysis should be structured, thorough, and provide actionable insights for further investigation.
"""
            )
        ])
        # Feature engineering prompt template
        self.feature_engineering_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a machine learning engineer specializing in feature engineering.
Based on the following dataset information, provide recommendations for feature engineering:
Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}
- Basic statistics:
{basic_stats}
- Top correlations:
{correlations}
Please provide comprehensive feature engineering recommendations that include:
1. Numerical feature transformations (scaling, normalization, log transforms, etc.)
2. Categorical feature encoding strategies
3. Feature interaction suggestions
4. Dimensionality reduction approaches if applicable
5. Time-based feature creation if applicable
6. Text processing techniques if there are text fields
7. Feature selection recommendations
For each recommendation, explain why it would be beneficial and how it could improve model performance.
Be specific to this dataset's characteristics rather than providing generic advice.
"""
            )
        ])
        # Data quality prompt template
        self.data_quality_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data quality expert.
Based on the following dataset information, provide data quality insights and recommendations:
Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}
- Missing values:
{missing_info}
- Basic statistics:
{basic_stats}
Please provide a comprehensive data quality assessment that includes:
1. Assessment of data completeness (missing values)
2. Identification of potential data inconsistencies or errors
3. Recommendations for data cleaning and preprocessing
4. Advice on handling outliers
5. Suggestions for data validation checks
6. Recommendations to improve data quality
Your assessment should be specific to this dataset and provide actionable recommendations.
"""
            )
        ])
        # QA prompt template
        self.qa_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data scientist answering questions about a dataset.
Based on the following dataset information, please answer the user's question:
Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}
- Basic statistics:
{basic_stats}
User's question: {question}
Please provide a clear, informative answer to the user's question based on the dataset information provided.
"""
            )
        ])

    def _init_chains(self):
        """Initialize all chains using the modern prompt | llm (RunnableSequence) pattern."""
        # EDA insights chain
        self.eda_chain = self.eda_prompt_template | self.llm
        # Feature engineering chain
        self.feature_engineering_chain = self.feature_engineering_prompt_template | self.llm
        # Data quality chain
        self.data_quality_chain = self.data_quality_prompt_template | self.llm
        # QA chain
        self.qa_chain = self.qa_prompt_template | self.llm

    def _format_columns_info(self, columns: List[str], dtypes: Dict[str, str]) -> str:
        """Format column names with their dtypes as a newline-separated bullet list."""
        return "\n".join([f"- {col} ({dtypes.get(col, 'unknown')})" for col in columns])

    def _format_missing_info(self, missing_values: Dict[str, tuple]) -> str:
        """Format missing-value counts for the prompt.

        Args:
            missing_values: Mapping of column name -> (count, percent).
        Returns:
            One bullet line per column that has missing values, or a
            "No missing values detected." sentinel when none do.
        """
        missing_info = "\n".join([f"- {col}: {count} missing values ({percent}%)"
                                  for col, (count, percent) in missing_values.items() if count > 0])
        if not missing_info:
            missing_info = "No missing values detected."
        return missing_info

    def _execute_chain(
        self,
        chain: "RunnableSequence",
        input_data: Dict[str, Any],
        operation_name: str
    ) -> str:
        """
        Execute a chain with tracking and error handling.

        Args:
            chain: The LangChain chain to execute
            input_data: The input data for the chain
            operation_name: Name of the operation for logging
        Returns:
            str: The generated text, or an error message on failure
        """
        try:
            start_time = time.time()
            # NOTE(review): get_openai_callback targets OpenAI; token counts for
            # Groq responses may read as 0 — confirm against langchain-groq docs.
            with get_openai_callback() as cb:
                result = chain.invoke(input_data).content
            elapsed_time = time.time() - start_time
            logger.info(f"{operation_name} generated in {elapsed_time:.2f} seconds")
            logger.info(f"Tokens used: {cb.total_tokens}, "
                        f"Prompt tokens: {cb.prompt_tokens}, "
                        f"Completion tokens: {cb.completion_tokens}")
            return result
        except Exception as e:
            # Deliberate best-effort: callers receive the error text, not an exception.
            error_msg = f"Error executing {operation_name.lower()}: {str(e)}"
            logger.error(error_msg)
            return error_msg

    def generate_eda_insights(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate EDA insights based on dataset information using LangChain.

        Args:
            dataset_info: Dictionary containing dataset analysis
        Returns:
            str: Detailed EDA insights and recommendations
        """
        logger.info("Generating EDA insights")
        # Format the input data
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )
        missing_info = self._format_missing_info(
            dataset_info.get("missing_values", {})
        )
        # Prepare input for the chain
        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "missing_info": missing_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "correlations": dataset_info.get("correlations", ""),
            "sample_data": dataset_info.get("sample_data", "N/A")
        }
        return self._execute_chain(self.eda_chain, input_data, "EDA insights")

    def generate_feature_engineering_recommendations(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate feature engineering recommendations based on dataset information using LangChain.

        Args:
            dataset_info: Dictionary containing dataset analysis
        Returns:
            str: Feature engineering recommendations
        """
        logger.info("Generating feature engineering recommendations")
        # Format the input data
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )
        # Prepare input for the chain
        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "correlations": dataset_info.get("correlations", "")
        }
        return self._execute_chain(
            self.feature_engineering_chain,
            input_data,
            "Feature engineering recommendations"
        )

    def generate_data_quality_insights(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate data quality insights based on dataset information using LangChain.

        Args:
            dataset_info: Dictionary containing dataset analysis
        Returns:
            str: Data quality insights and improvement recommendations
        """
        logger.info("Generating data quality insights")
        # Format the input data
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )
        missing_info = self._format_missing_info(
            dataset_info.get("missing_values", {})
        )
        # Prepare input for the chain
        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "missing_info": missing_info,
            "basic_stats": dataset_info.get("basic_stats", "")
        }
        return self._execute_chain(
            self.data_quality_chain,
            input_data,
            "Data quality insights"
        )

    def answer_dataset_question(self, question: str, dataset_info: Dict[str, Any]) -> str:
        """
        Answer a specific question about the dataset using LangChain.

        Args:
            question: User's question about the dataset
            dataset_info: Dictionary containing dataset analysis
        Returns:
            str: Answer to the user's question
        """
        logger.info(f"Answering dataset question: {question[:50]}...")
        # Format the input data
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )
        # Prepare input for the chain
        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "question": question
        }
        return self._execute_chain(
            self.qa_chain,
            input_data,
            "Answer"
        )

    def answer_with_memory(self, question: str, dataset_info: Dict[str, Any], memory) -> str:
        """
        Answer a question with conversation memory to maintain context.

        Args:
            question: User's question about the dataset
            dataset_info: Dictionary containing dataset analysis
            memory: ConversationBufferMemory instance to store conversation history
        Returns:
            str: Answer to the user's question with conversation context
        """
        logger.info(f"Answering with memory: {question[:50]}...")
        # Format the input data for the dataset context
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )
        # Create a custom prompt that includes both conversation history and dataset info
        memory_prompt = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data scientist answering questions about a dataset.
The following is information about the dataset:
Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}
- Basic statistics:
{basic_stats}
Previous conversation:
{chat_history}
User's new question: {question}
Please provide a clear, informative answer to the user's question. Take into account the previous conversation for context. Make your answer specific to the dataset information provided."""
            )
        ])
        # Create a chain that uses both the prompt and memory
        memory_chain = memory_prompt | self.llm
        # Load and flatten the conversation history. With return_messages=True the
        # memory yields message objects; with return_messages=False it yields a
        # pre-formatted string — the original iterated that string char-by-char.
        try:
            chat_history = memory.load_memory_variables({})["chat_history"]
            if isinstance(chat_history, str):
                chat_history_str = chat_history
            else:
                chat_history_str = "\n".join([f"{msg.type}: {msg.content}" for msg in chat_history])
        except Exception as e:
            logger.warning(f"Error loading memory: {str(e)}. Using empty chat history.")
            chat_history_str = "No previous conversation."
        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "question": question,
            "chat_history": chat_history_str
        }
        # Execute the chain and get a response
        response = self._execute_chain(
            memory_chain,
            input_data,
            "Answer with memory"
        )
        # Save the interaction to memory so follow-up questions keep context
        memory.save_context(
            {"input": question},
            {"output": response}
        )
        return response