"""
LLM Inference Module

This module handles all interactions with the Groq API via LangChain,
allowing the application to generate EDA insights, feature engineering
recommendations, data quality assessments, and answers to free-form
questions from dataset analysis.
"""
| |
|
| | import os |
| | from dotenv import load_dotenv |
| | import logging |
| | import time |
| | from typing import Dict, Any, List, Optional |
| | from langchain_community.callbacks.manager import get_openai_callback |
| |
|
| | |
| | from langchain_groq import ChatGroq |
| | from langchain_core.messages import HumanMessage |
| | from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate |
| | |
| | from langchain_core.runnables import RunnableSequence |
| |
|
| | |
# Configure logging once at import time; every record carries a timestamp,
# the logger name and the level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Load variables from a .env file (if any) into the process environment
# before reading the API key from it.
load_dotenv()
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Fail fast at import time: nothing in this module works without a key.
# An empty string is treated the same as a missing variable.
if GROQ_API_KEY is None or GROQ_API_KEY == "":
    raise ValueError(
        "GROQ_API_KEY not found in environment variables. "
        "Please add it to your .env file."
    )

# Shared module-level client; initialization errors are logged and then
# propagated unchanged to whoever imported the module.
try:
    llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name="llama3-8b-8192")
    logger.info("Successfully initialized Groq client")
except Exception as init_error:
    logger.error(f"Failed to initialize Groq client: {str(init_error)}")
    raise
| |
|
class LLMInference:
    """Generate dataset insights from a Groq-hosted LLM through LangChain.

    On construction, four prompt templates are built and each is piped into
    the chat model to form a chain: EDA insights, feature engineering
    recommendations, data quality assessment, and free-form question
    answering. Every public method takes a ``dataset_info`` dictionary,
    reads a fixed set of keys from it (falling back to defaults when a key
    is absent), and returns the model's text response.

    Failures while running a chain are NOT raised: ``_execute_chain`` logs
    them and returns the error message as the result string.
    """

    # Model the shared module-level ``llm`` client is configured with.
    _DEFAULT_MODEL_ID = "llama3-8b-8192"

    def __init__(self, model_id: str = _DEFAULT_MODEL_ID):
        """Initialize the LLM inference class with a Groq model.

        Args:
            model_id: Groq model name. With the default value the shared
                module-level client is reused; any other value gets a
                dedicated client so the requested model is actually the
                one being called.
        """
        self.model_id = model_id

        if model_id == self._DEFAULT_MODEL_ID:
            self.llm = llm
        else:
            # Previously the argument was stored and logged but the default
            # client was used regardless of what the caller asked for.
            self.llm = ChatGroq(model_name=model_id, groq_api_key=GROQ_API_KEY)

        self._init_prompt_templates()
        self._init_chains()

        logger.info(f"LLMInference initialized with model: {model_id}")

    @staticmethod
    def _human_prompt(template_text: str):
        """Wrap ``template_text`` in a chat prompt holding one human message."""
        return ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(template_text)
        ])

    def _init_prompt_templates(self):
        """Initialize all prompt templates.

        Placeholders in braces are filled from the ``input_data`` dict
        passed to the corresponding chain. The template text is kept
        flush-left so no indentation leaks into the prompt sent to the model.
        """
        # Placeholders: shape, columns_info, missing_info, basic_stats,
        # correlations, sample_data.
        self.eda_prompt_template = self._human_prompt(
            """You are a data scientist tasked with performing Exploratory Data Analysis (EDA) on a dataset.
Based on the following dataset information, provide comprehensive EDA insights:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Missing values:
{missing_info}

- Basic statistics:
{basic_stats}

- Top correlations:
{correlations}

- Sample data:
{sample_data}

Please provide a detailed EDA analysis that includes:

1. Summary of the dataset (what it appears to be about, key features, etc.)
2. Distribution analysis of key variables
3. Relationship analysis between variables
4. Identification of patterns, outliers, or anomalies
5. Recommended visualizations that would be insightful
6. Initial hypotheses based on the data

Your analysis should be structured, thorough, and provide actionable insights for further investigation.
"""
        )

        # Placeholders: shape, columns_info, basic_stats, correlations.
        self.feature_engineering_prompt_template = self._human_prompt(
            """You are a machine learning engineer specializing in feature engineering.
Based on the following dataset information, provide recommendations for feature engineering:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

- Top correlations:
{correlations}

Please provide comprehensive feature engineering recommendations that include:

1. Numerical feature transformations (scaling, normalization, log transforms, etc.)
2. Categorical feature encoding strategies
3. Feature interaction suggestions
4. Dimensionality reduction approaches if applicable
5. Time-based feature creation if applicable
6. Text processing techniques if there are text fields
7. Feature selection recommendations

For each recommendation, explain why it would be beneficial and how it could improve model performance.
Be specific to this dataset's characteristics rather than providing generic advice.
"""
        )

        # Placeholders: shape, columns_info, missing_info, basic_stats.
        self.data_quality_prompt_template = self._human_prompt(
            """You are a data quality expert.
Based on the following dataset information, provide data quality insights and recommendations:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Missing values:
{missing_info}

- Basic statistics:
{basic_stats}

Please provide a comprehensive data quality assessment that includes:

1. Assessment of data completeness (missing values)
2. Identification of potential data inconsistencies or errors
3. Recommendations for data cleaning and preprocessing
4. Advice on handling outliers
5. Suggestions for data validation checks
6. Recommendations to improve data quality

Your assessment should be specific to this dataset and provide actionable recommendations.
"""
        )

        # Placeholders: shape, columns_info, basic_stats, question.
        self.qa_prompt_template = self._human_prompt(
            """You are a data scientist answering questions about a dataset.
Based on the following dataset information, please answer the user's question:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

User's question: {question}

Please provide a clear, informative answer to the user's question based on the dataset information provided.
"""
        )

    def _init_chains(self):
        """Initialize all chains by piping each prompt template into the LLM."""
        self.eda_chain = self.eda_prompt_template | self.llm
        self.feature_engineering_chain = self.feature_engineering_prompt_template | self.llm
        self.data_quality_chain = self.data_quality_prompt_template | self.llm
        self.qa_chain = self.qa_prompt_template | self.llm

    def _format_columns_info(self, columns: List[str], dtypes: Dict[str, str]) -> str:
        """Format columns info for the prompt.

        Args:
            columns: Column names, in the order they should be listed.
            dtypes: Mapping of column name to a type label.

        Returns:
            One ``- name (type)`` line per column, newline-joined; columns
            missing from ``dtypes`` are labelled ``unknown``. Empty string
            when ``columns`` is empty.
        """
        return "\n".join(f"- {col} ({dtypes.get(col, 'unknown')})" for col in columns)

    def _format_missing_info(self, missing_values: Dict[str, tuple]) -> str:
        """Format missing values info for the prompt.

        Args:
            missing_values: Mapping of column name to a ``(count, percent)``
                pair.

        Returns:
            One line per column whose count is greater than zero, or a fixed
            "no missing values" sentence when no column qualifies.
        """
        missing_info = "\n".join(
            f"- {col}: {count} missing values ({percent}%)"
            for col, (count, percent) in missing_values.items()
            if count > 0
        )
        return missing_info or "No missing values detected."

    def _common_inputs(self, dataset_info: Dict[str, Any]) -> Dict[str, Any]:
        """Build the prompt inputs shared by every chain.

        Returns a new dict with ``shape``, ``columns_info`` and
        ``basic_stats``; callers add the keys specific to their prompt.
        """
        return {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": self._format_columns_info(
                dataset_info.get("columns", []),
                dataset_info.get("dtypes", {}),
            ),
            "basic_stats": dataset_info.get("basic_stats", ""),
        }

    def _execute_chain(
        self,
        chain: "RunnableSequence",
        input_data: Dict[str, Any],
        operation_name: str
    ) -> str:
        """
        Execute a chain with tracking and error handling

        Args:
            chain: The LangChain chain to execute
            input_data: The input data for the chain
            operation_name: Name of the operation for logging

        Returns:
            str: The generated text. If anything goes wrong, the error
            message is returned instead of raising, so callers always
            receive a string.
        """
        try:
            # Monotonic clock: immune to system clock adjustments mid-call.
            start_time = time.perf_counter()
            # NOTE(review): token counts come from the OpenAI callback
            # handler; whether Groq responses populate them depends on the
            # installed LangChain versions - confirm the logged numbers.
            with get_openai_callback() as cb:
                result = chain.invoke(input_data).content
            elapsed_time = time.perf_counter() - start_time

            logger.info(f"{operation_name} generated in {elapsed_time:.2f} seconds")
            logger.info(f"Tokens used: {cb.total_tokens}, "
                        f"Prompt tokens: {cb.prompt_tokens}, "
                        f"Completion tokens: {cb.completion_tokens}")

            return result
        except Exception as e:
            # Deliberate boundary catch: the error text becomes the result.
            # exc_info keeps the traceback in the log, which the returned
            # string alone would lose.
            error_msg = f"Error executing {operation_name.lower()}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return error_msg

    def generate_eda_insights(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate EDA insights based on dataset information using LangChain

        Args:
            dataset_info: Dictionary containing dataset analysis. Keys read:
                shape, columns, dtypes, missing_values, basic_stats,
                correlations, sample_data (all optional).

        Returns:
            str: Detailed EDA insights and recommendations
        """
        logger.info("Generating EDA insights")

        input_data = self._common_inputs(dataset_info)
        input_data["missing_info"] = self._format_missing_info(
            dataset_info.get("missing_values", {})
        )
        input_data["correlations"] = dataset_info.get("correlations", "")
        input_data["sample_data"] = dataset_info.get("sample_data", "N/A")

        return self._execute_chain(self.eda_chain, input_data, "EDA insights")

    def generate_feature_engineering_recommendations(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate feature engineering recommendations based on dataset information using LangChain

        Args:
            dataset_info: Dictionary containing dataset analysis. Keys read:
                shape, columns, dtypes, basic_stats, correlations
                (all optional).

        Returns:
            str: Feature engineering recommendations
        """
        logger.info("Generating feature engineering recommendations")

        input_data = self._common_inputs(dataset_info)
        input_data["correlations"] = dataset_info.get("correlations", "")

        return self._execute_chain(
            self.feature_engineering_chain,
            input_data,
            "Feature engineering recommendations"
        )

    def generate_data_quality_insights(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate data quality insights based on dataset information using LangChain

        Args:
            dataset_info: Dictionary containing dataset analysis. Keys read:
                shape, columns, dtypes, missing_values, basic_stats
                (all optional).

        Returns:
            str: Data quality insights and improvement recommendations
        """
        logger.info("Generating data quality insights")

        input_data = self._common_inputs(dataset_info)
        input_data["missing_info"] = self._format_missing_info(
            dataset_info.get("missing_values", {})
        )

        return self._execute_chain(
            self.data_quality_chain,
            input_data,
            "Data quality insights"
        )

    def answer_dataset_question(self, question: str, dataset_info: Dict[str, Any]) -> str:
        """
        Answer a specific question about the dataset using LangChain

        Args:
            question: User's question about the dataset
            dataset_info: Dictionary containing dataset analysis. Keys read:
                shape, columns, dtypes, basic_stats (all optional).

        Returns:
            str: Answer to the user's question
        """
        # Only the first 50 characters are logged to keep log lines short.
        logger.info(f"Answering dataset question: {question[:50]}...")

        input_data = self._common_inputs(dataset_info)
        input_data["question"] = question

        return self._execute_chain(
            self.qa_chain,
            input_data,
            "Answer"
        )
| |
|