|
|
""" |
|
|
LLM Inference Module |
|
|
|
|
|
This module handles all interactions with the Groq API via LangChain, |
|
|
allowing the application to generate EDA insights and feature engineering |
|
|
recommendations from dataset analysis. |
|
|
""" |
|
|
|
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
import logging |
|
|
import time |
|
|
from typing import Dict, Any, List, Optional |
|
|
from langchain_community.callbacks.manager import get_openai_callback |
|
|
|
|
|
|
|
|
from langchain_groq import ChatGroq |
|
|
from langchain_core.messages import HumanMessage |
|
|
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate |
|
|
from langchain_community.callbacks.manager import get_openai_callback |
|
|
from langchain_core.runnables import RunnableSequence |
|
|
|
|
|
|
|
|
# Configure root logging once at import time so every logger in the app
# inherits a timestamped format.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Groq API key read from the environment; required for any LLM call below.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Fail fast at import time: nothing in this module can work without a key.
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found in environment variables. Please add it to your .env file.")

# Shared module-level Groq chat client, reused by LLMInference instances.
try:
    llm = ChatGroq(model_name="llama3-8b-8192", groq_api_key=GROQ_API_KEY)
    logger.info("Successfully initialized Groq client")
except Exception as e:
    # Log and re-raise so the import fails loudly with the original traceback.
    logger.error(f"Failed to initialize Groq client: {str(e)}")
    raise
|
|
|
|
|
class LLMInference:
    """Class for interacting with LLM via Groq API using LangChain.

    Wraps prompt templates and runnable chains for four operations:
    EDA insights, feature engineering recommendations, data quality
    insights, and dataset Q&A (with or without conversation memory).
    """

    # Model served by the shared module-level client; requests for this
    # model reuse that client instead of creating a new one.
    DEFAULT_MODEL_ID = "llama3-8b-8192"

    def __init__(self, model_id: str = "llama3-8b-8192"):
        """Initialize the LLM inference class with a Groq model.

        Args:
            model_id: Groq model identifier. When it matches the default,
                the shared module-level client is reused; otherwise a
                dedicated ChatGroq client is created for the requested model.
        """
        self.model_id = model_id
        # BUG FIX: the original always bound the module-level `llm`, silently
        # ignoring a non-default `model_id`. Honor the requested model here.
        if model_id == self.DEFAULT_MODEL_ID:
            self.llm = llm
        else:
            self.llm = ChatGroq(model_name=model_id, groq_api_key=GROQ_API_KEY)

        self._init_prompt_templates()
        self._init_chains()

        logger.info(f"LLMInference initialized with model: {model_id}")

    def _init_prompt_templates(self):
        """Initialize all prompt templates (built once, reused per call)."""

        # Template for exploratory data analysis insights.
        self.eda_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data scientist tasked with performing Exploratory Data Analysis (EDA) on a dataset.
Based on the following dataset information, provide comprehensive EDA insights:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Missing values:
{missing_info}

- Basic statistics:
{basic_stats}

- Top correlations:
{correlations}

- Sample data:
{sample_data}

Please provide a detailed EDA analysis that includes:

1. Summary of the dataset (what it appears to be about, key features, etc.)
2. Distribution analysis of key variables
3. Relationship analysis between variables
4. Identification of patterns, outliers, or anomalies
5. Recommended visualizations that would be insightful
6. Initial hypotheses based on the data

Your analysis should be structured, thorough, and provide actionable insights for further investigation.
"""
            )
        ])

        # Template for feature engineering recommendations.
        self.feature_engineering_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a machine learning engineer specializing in feature engineering.
Based on the following dataset information, provide recommendations for feature engineering:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

- Top correlations:
{correlations}

Please provide comprehensive feature engineering recommendations that include:

1. Numerical feature transformations (scaling, normalization, log transforms, etc.)
2. Categorical feature encoding strategies
3. Feature interaction suggestions
4. Dimensionality reduction approaches if applicable
5. Time-based feature creation if applicable
6. Text processing techniques if there are text fields
7. Feature selection recommendations

For each recommendation, explain why it would be beneficial and how it could improve model performance.
Be specific to this dataset's characteristics rather than providing generic advice.
"""
            )
        ])

        # Template for data quality assessment.
        self.data_quality_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data quality expert.
Based on the following dataset information, provide data quality insights and recommendations:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Missing values:
{missing_info}

- Basic statistics:
{basic_stats}

Please provide a comprehensive data quality assessment that includes:

1. Assessment of data completeness (missing values)
2. Identification of potential data inconsistencies or errors
3. Recommendations for data cleaning and preprocessing
4. Advice on handling outliers
5. Suggestions for data validation checks
6. Recommendations to improve data quality

Your assessment should be specific to this dataset and provide actionable recommendations.
"""
            )
        ])

        # Template for single-turn dataset Q&A.
        self.qa_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data scientist answering questions about a dataset.
Based on the following dataset information, please answer the user's question:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

User's question: {question}

Please provide a clear, informative answer to the user's question based on the dataset information provided.
"""
            )
        ])

        # Template for memory-aware Q&A. IMPROVEMENT: the original rebuilt
        # this identical static template on every answer_with_memory() call;
        # building it once here avoids the repeated construction.
        self.memory_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data scientist answering questions about a dataset.
The following is information about the dataset:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

Previous conversation:
{chat_history}

User's new question: {question}

Please provide a clear, informative answer to the user's question. Take into account the previous conversation for context. Make your answer specific to the dataset information provided."""
            )
        ])

    def _init_chains(self):
        """Initialize all chains using modern RunnableSequence pattern."""
        self.eda_chain = self.eda_prompt_template | self.llm
        self.feature_engineering_chain = self.feature_engineering_prompt_template | self.llm
        self.data_quality_chain = self.data_quality_prompt_template | self.llm
        self.qa_chain = self.qa_prompt_template | self.llm
        self.memory_chain = self.memory_prompt_template | self.llm

    def _format_columns_info(self, columns: List[str], dtypes: Dict[str, str]) -> str:
        """Format columns info for prompt.

        Columns absent from `dtypes` are labeled 'unknown'. Returns one
        "- name (dtype)" bullet per column, newline-separated.
        """
        return "\n".join([f"- {col} ({dtypes.get(col, 'unknown')})" for col in columns])

    def _format_missing_info(self, missing_values: Dict[str, tuple]) -> str:
        """Format missing values info for prompt.

        Args:
            missing_values: Mapping of column name -> (count, percent).

        Returns:
            str: One bullet per column with count > 0, or a sentinel string
                when nothing is missing.
        """
        missing_info = "\n".join([f"- {col}: {count} missing values ({percent}%)"
                                  for col, (count, percent) in missing_values.items() if count > 0])

        if not missing_info:
            missing_info = "No missing values detected."

        return missing_info

    def _execute_chain(
        self,
        chain: "RunnableSequence",
        input_data: Dict[str, Any],
        operation_name: str
    ) -> str:
        """
        Execute a chain with tracking and error handling.

        Args:
            chain: The LangChain chain to execute
            input_data: The input data for the chain
            operation_name: Name of the operation for logging

        Returns:
            str: The generated text, or an error message string on failure
                (errors are logged and returned rather than raised).
        """
        try:
            start_time = time.time()
            # NOTE(review): get_openai_callback is OpenAI-specific; with Groq
            # models the token counts it reports may be zero — confirm, or
            # read token usage from the response metadata instead.
            with get_openai_callback() as cb:
                result = chain.invoke(input_data).content
                elapsed_time = time.time() - start_time

            logger.info(f"{operation_name} generated in {elapsed_time:.2f} seconds")
            logger.info(f"Tokens used: {cb.total_tokens}, "
                        f"Prompt tokens: {cb.prompt_tokens}, "
                        f"Completion tokens: {cb.completion_tokens}")

            return result
        except Exception as e:
            error_msg = f"Error executing {operation_name.lower()}: {str(e)}"
            logger.error(error_msg)
            return error_msg

    def generate_eda_insights(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate EDA insights based on dataset information using LangChain.

        Args:
            dataset_info: Dictionary containing dataset analysis

        Returns:
            str: Detailed EDA insights and recommendations
        """
        logger.info("Generating EDA insights")

        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        missing_info = self._format_missing_info(
            dataset_info.get("missing_values", {})
        )

        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "missing_info": missing_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "correlations": dataset_info.get("correlations", ""),
            "sample_data": dataset_info.get("sample_data", "N/A")
        }

        return self._execute_chain(self.eda_chain, input_data, "EDA insights")

    def generate_feature_engineering_recommendations(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate feature engineering recommendations based on dataset information using LangChain.

        Args:
            dataset_info: Dictionary containing dataset analysis

        Returns:
            str: Feature engineering recommendations
        """
        logger.info("Generating feature engineering recommendations")

        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "correlations": dataset_info.get("correlations", "")
        }

        return self._execute_chain(
            self.feature_engineering_chain,
            input_data,
            "Feature engineering recommendations"
        )

    def generate_data_quality_insights(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate data quality insights based on dataset information using LangChain.

        Args:
            dataset_info: Dictionary containing dataset analysis

        Returns:
            str: Data quality insights and improvement recommendations
        """
        logger.info("Generating data quality insights")

        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        missing_info = self._format_missing_info(
            dataset_info.get("missing_values", {})
        )

        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "missing_info": missing_info,
            "basic_stats": dataset_info.get("basic_stats", "")
        }

        return self._execute_chain(
            self.data_quality_chain,
            input_data,
            "Data quality insights"
        )

    def answer_dataset_question(self, question: str, dataset_info: Dict[str, Any]) -> str:
        """
        Answer a specific question about the dataset using LangChain.

        Args:
            question: User's question about the dataset
            dataset_info: Dictionary containing dataset analysis

        Returns:
            str: Answer to the user's question
        """
        logger.info(f"Answering dataset question: {question[:50]}...")

        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "question": question
        }

        return self._execute_chain(
            self.qa_chain,
            input_data,
            "Answer"
        )

    def answer_with_memory(self, question: str, dataset_info: Dict[str, Any], memory) -> str:
        """
        Answer a question with conversation memory to maintain context.

        Args:
            question: User's question about the dataset
            dataset_info: Dictionary containing dataset analysis
            memory: ConversationBufferMemory instance to store conversation
                history — presumably exposes load_memory_variables() /
                save_context(); TODO confirm the concrete type at call sites.

        Returns:
            str: Answer to the user's question with conversation context
        """
        logger.info(f"Answering with memory: {question[:50]}...")

        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        # Flatten stored messages into "type: content" lines; fall back to a
        # sentinel when the memory is unreadable (e.g. missing key).
        try:
            chat_history = memory.load_memory_variables({})["chat_history"]
            chat_history_str = "\n".join([f"{msg.type}: {msg.content}" for msg in chat_history])
        except Exception as e:
            logger.warning(f"Error loading memory: {str(e)}. Using empty chat history.")
            chat_history_str = "No previous conversation."

        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "question": question,
            "chat_history": chat_history_str
        }

        # Uses the memory chain built once in _init_chains (the original
        # rebuilt the template and chain on every call).
        response = self._execute_chain(
            self.memory_chain,
            input_data,
            "Answer with memory"
        )

        # Persist the turn so subsequent calls see it in {chat_history}.
        # NOTE(review): this also stores error-message responses, matching
        # the original behavior.
        memory.save_context(
            {"input": question},
            {"output": response}
        )

        return response
|
|
|
|
|
|