File size: 15,599 Bytes
7e960ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34f3676
7e960ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34f3676
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
"""
LLM Inference Module

This module handles all interactions with the Groq API via LangChain,
allowing the application to generate EDA insights and feature engineering
recommendations from dataset analysis.
"""

import os
from dotenv import load_dotenv
import logging
import time
from typing import Dict, Any, List, Optional
from langchain_community.callbacks.manager import get_openai_callback

# LangChain imports
from langchain_groq import ChatGroq
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_community.callbacks.manager import get_openai_callback
from langchain_core.runnables import RunnableSequence

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found in environment variables. Please add it to your .env file.")

# Create LLM model
try:
    llm = ChatGroq(model_name="llama3-8b-8192", groq_api_key=GROQ_API_KEY)
    logger.info("Successfully initialized Groq client")
except Exception as e:
    logger.error(f"Failed to initialize Groq client: {str(e)}")
    raise

class LLMInference:
    """Client for generating dataset analyses via the Groq API using LangChain.

    Builds one prompt/LLM chain per analysis type (EDA insights, feature
    engineering, data quality, Q&A, and memory-aware Q&A) at construction
    time and exposes a public method for each. All public methods return
    the generated text, or an error-message string if the LLM call fails
    (callers display whatever string comes back; no exception escapes).
    """

    def __init__(self, model_id: str = "llama3-8b-8192"):
        """Initialize the LLM inference class with a Groq model.

        Args:
            model_id: Groq model identifier. The shared module-level client
                is reused for the default model; any other id gets its own
                ``ChatGroq`` client so the argument is actually honored.
        """
        self.model_id = model_id
        # Fix: the original unconditionally bound the module-level default
        # client, silently ignoring a non-default `model_id`.
        if model_id == "llama3-8b-8192":
            self.llm = llm
        else:
            self.llm = ChatGroq(model_name=model_id, groq_api_key=GROQ_API_KEY)

        # Initialize prompt templates and chains
        self._init_prompt_templates()
        self._init_chains()

        logger.info(f"LLMInference initialized with model: {model_id}")

    def _init_prompt_templates(self):
        """Initialize all prompt templates (built once, reused per call)."""

        # EDA insights prompt template
        self.eda_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data scientist tasked with performing Exploratory Data Analysis (EDA) on a dataset. 
Based on the following dataset information, provide comprehensive EDA insights:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Missing values:
{missing_info}

- Basic statistics:
{basic_stats}

- Top correlations:
{correlations}

- Sample data:
{sample_data}

Please provide a detailed EDA analysis that includes:

1. Summary of the dataset (what it appears to be about, key features, etc.)
2. Distribution analysis of key variables
3. Relationship analysis between variables
4. Identification of patterns, outliers, or anomalies
5. Recommended visualizations that would be insightful
6. Initial hypotheses based on the data

Your analysis should be structured, thorough, and provide actionable insights for further investigation.
"""
            )
        ])

        # Feature engineering prompt template
        self.feature_engineering_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a machine learning engineer specializing in feature engineering. 
Based on the following dataset information, provide recommendations for feature engineering:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

- Top correlations:
{correlations}

Please provide comprehensive feature engineering recommendations that include:

1. Numerical feature transformations (scaling, normalization, log transforms, etc.)
2. Categorical feature encoding strategies
3. Feature interaction suggestions
4. Dimensionality reduction approaches if applicable
5. Time-based feature creation if applicable
6. Text processing techniques if there are text fields
7. Feature selection recommendations

For each recommendation, explain why it would be beneficial and how it could improve model performance.
Be specific to this dataset's characteristics rather than providing generic advice.
"""
            )
        ])

        # Data quality prompt template
        self.data_quality_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data quality expert. 
Based on the following dataset information, provide data quality insights and recommendations:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Missing values:
{missing_info}

- Basic statistics:
{basic_stats}

Please provide a comprehensive data quality assessment that includes:

1. Assessment of data completeness (missing values)
2. Identification of potential data inconsistencies or errors
3. Recommendations for data cleaning and preprocessing
4. Advice on handling outliers
5. Suggestions for data validation checks
6. Recommendations to improve data quality

Your assessment should be specific to this dataset and provide actionable recommendations.
"""
            )
        ])

        # QA prompt template
        self.qa_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data scientist answering questions about a dataset. 
Based on the following dataset information, please answer the user's question:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

User's question: {question}

Please provide a clear, informative answer to the user's question based on the dataset information provided.
"""
            )
        ])

        # Memory-aware QA prompt template. Fix: the original rebuilt this
        # template (and its chain) inside answer_with_memory on every call;
        # it is invariant, so build it once here.
        self.memory_prompt_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template(
                """You are a data scientist answering questions about a dataset. 
The following is information about the dataset:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

Previous conversation:
{chat_history}

User's new question: {question}

Please provide a clear, informative answer to the user's question. Take into account the previous conversation for context. Make your answer specific to the dataset information provided."""
            )
        ])

    def _init_chains(self):
        """Initialize all chains using the modern prompt | llm composition."""

        # EDA insights chain
        self.eda_chain = self.eda_prompt_template | self.llm

        # Feature engineering chain
        self.feature_engineering_chain = self.feature_engineering_prompt_template | self.llm

        # Data quality chain
        self.data_quality_chain = self.data_quality_prompt_template | self.llm

        # QA chain
        self.qa_chain = self.qa_prompt_template | self.llm

        # Memory-aware QA chain (see _init_prompt_templates)
        self.memory_chain = self.memory_prompt_template | self.llm

    def _format_columns_info(self, columns: List[str], dtypes: Dict[str, str]) -> str:
        """Render columns as '- name (dtype)' lines; unknown dtypes fall back to 'unknown'."""
        return "\n".join([f"- {col} ({dtypes.get(col, 'unknown')})" for col in columns])

    def _format_missing_info(self, missing_values: Dict[str, tuple]) -> str:
        """Render per-column missing-value counts for the prompt.

        Args:
            missing_values: Mapping of column name -> (count, percent) tuples.
                Columns with zero missing values are omitted.

        Returns:
            str: One '- col: N missing values (P%)' line per affected column,
                or a fixed sentence when nothing is missing.
        """
        missing_info = "\n".join([f"- {col}: {count} missing values ({percent}%)"
                               for col, (count, percent) in missing_values.items() if count > 0])

        if not missing_info:
            missing_info = "No missing values detected."

        return missing_info

    def _execute_chain(
        self,
        chain: "RunnableSequence",
        input_data: Dict[str, Any],
        operation_name: str
    ) -> str:
        """
        Execute a chain with timing/token tracking and error handling.

        Args:
            chain: The LangChain chain to execute
            input_data: The input data for the chain
            operation_name: Name of the operation for logging

        Returns:
            str: The generated text, or an error-message string on failure
                (deliberate: callers render the returned string directly).
        """
        try:
            start_time = time.time()
            # NOTE(review): get_openai_callback tracks OpenAI-style usage;
            # with a Groq backend the token counts may read as zero — confirm.
            with get_openai_callback() as cb:
                result = chain.invoke(input_data).content
            elapsed_time = time.time() - start_time

            logger.info(f"{operation_name} generated in {elapsed_time:.2f} seconds")
            logger.info(f"Tokens used: {cb.total_tokens}, "
                      f"Prompt tokens: {cb.prompt_tokens}, "
                      f"Completion tokens: {cb.completion_tokens}")

            return result
        except Exception as e:
            error_msg = f"Error executing {operation_name.lower()}: {str(e)}"
            logger.error(error_msg)
            return error_msg

    def generate_eda_insights(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate EDA insights based on dataset information using LangChain

        Args:
            dataset_info: Dictionary containing dataset analysis

        Returns:
            str: Detailed EDA insights and recommendations
        """
        logger.info("Generating EDA insights")

        # Format the input data
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        missing_info = self._format_missing_info(
            dataset_info.get("missing_values", {})
        )

        # Prepare input for the chain; every key maps to a prompt placeholder.
        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "missing_info": missing_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "correlations": dataset_info.get("correlations", ""),
            "sample_data": dataset_info.get("sample_data", "N/A")
        }

        return self._execute_chain(self.eda_chain, input_data, "EDA insights")

    def generate_feature_engineering_recommendations(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate feature engineering recommendations based on dataset information using LangChain

        Args:
            dataset_info: Dictionary containing dataset analysis

        Returns:
            str: Feature engineering recommendations
        """
        logger.info("Generating feature engineering recommendations")

        # Format the input data
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        # Prepare input for the chain
        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "correlations": dataset_info.get("correlations", "")
        }

        return self._execute_chain(
            self.feature_engineering_chain,
            input_data,
            "Feature engineering recommendations"
        )

    def generate_data_quality_insights(self, dataset_info: Dict[str, Any]) -> str:
        """
        Generate data quality insights based on dataset information using LangChain

        Args:
            dataset_info: Dictionary containing dataset analysis

        Returns:
            str: Data quality insights and improvement recommendations
        """
        logger.info("Generating data quality insights")

        # Format the input data
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        missing_info = self._format_missing_info(
            dataset_info.get("missing_values", {})
        )

        # Prepare input for the chain
        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "missing_info": missing_info,
            "basic_stats": dataset_info.get("basic_stats", "")
        }

        return self._execute_chain(
            self.data_quality_chain,
            input_data,
            "Data quality insights"
        )

    def answer_dataset_question(self, question: str, dataset_info: Dict[str, Any]) -> str:
        """
        Answer a specific question about the dataset using LangChain

        Args:
            question: User's question about the dataset
            dataset_info: Dictionary containing dataset analysis

        Returns:
            str: Answer to the user's question
        """
        logger.info(f"Answering dataset question: {question[:50]}...")

        # Format the input data
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        # Prepare input for the chain
        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "question": question
        }

        return self._execute_chain(
            self.qa_chain,
            input_data,
            "Answer"
        )

    def answer_with_memory(self, question: str, dataset_info: Dict[str, Any], memory) -> str:
        """
        Answer a question with conversation memory to maintain context

        Args:
            question: User's question about the dataset
            dataset_info: Dictionary containing dataset analysis
            memory: ConversationBufferMemory instance to store conversation history

        Returns:
            str: Answer to the user's question with conversation context
        """
        logger.info(f"Answering with memory: {question[:50]}...")

        # Format the input data for the dataset context
        columns_info = self._format_columns_info(
            dataset_info.get("columns", []),
            dataset_info.get("dtypes", {})
        )

        # Flatten stored messages into a plain-text transcript for the prompt;
        # fall back to a placeholder if the memory object cannot be read.
        try:
            chat_history = memory.load_memory_variables({})["chat_history"]
            chat_history_str = "\n".join([f"{msg.type}: {msg.content}" for msg in chat_history])
        except Exception as e:
            logger.warning(f"Error loading memory: {str(e)}. Using empty chat history.")
            chat_history_str = "No previous conversation."

        input_data = {
            "shape": dataset_info.get("shape", "N/A"),
            "columns_info": columns_info,
            "basic_stats": dataset_info.get("basic_stats", ""),
            "question": question,
            "chat_history": chat_history_str
        }

        # Execute the prebuilt memory-aware chain (see _init_chains)
        response = self._execute_chain(
            self.memory_chain,
            input_data,
            "Answer with memory"
        )

        # Save the interaction to memory so the next call sees this turn
        memory.save_context(
            {"input": question},
            {"output": response}
        )

        return response