vashu2425 committed on
Commit
34f3676
·
verified ·
1 Parent(s): adecbfd

Update llm_inference.py

Browse files
Files changed (1) hide show
  1. llm_inference.py +477 -1
llm_inference.py CHANGED
@@ -1,3 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  LLM Inference Module
3
 
@@ -17,7 +414,7 @@ from langchain_community.callbacks.manager import get_openai_callback
17
  from langchain_groq import ChatGroq
18
  from langchain_core.messages import HumanMessage
19
  from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
20
- # from langchain_community.callbacks.manager import get_openai_callbatck
21
  from langchain_core.runnables import RunnableSequence
22
 
23
  # Configure logging
@@ -375,3 +772,82 @@ Please provide a clear, informative answer to the user's question based on the d
375
  input_data,
376
  "Answer"
377
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # LLM Inference Module
3
+
4
+ # This module handles all interactions with the Groq API via LangChain,
5
+ # allowing the application to generate EDA insights and feature engineering
6
+ # recommendations from dataset analysis.
7
+ # """
8
+
9
+ # import os
10
+ # from dotenv import load_dotenv
11
+ # import logging
12
+ # import time
13
+ # from typing import Dict, Any, List, Optional
14
+ # from langchain_community.callbacks.manager import get_openai_callback
15
+
16
+ # # LangChain imports
17
+ # from langchain_groq import ChatGroq
18
+ # from langchain_core.messages import HumanMessage
19
+ # from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
20
+ # # from langchain_community.callbacks.manager import get_openai_callbatck
21
+ # from langchain_core.runnables import RunnableSequence
22
+
23
+ # # Configure logging
24
+ # logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
25
+ # logger = logging.getLogger(__name__)
26
+
27
+ # # Load environment variables
28
+ # load_dotenv()
29
+ # GROQ_API_KEY = os.getenv("GROQ_API_KEY")
30
+
31
+ # if not GROQ_API_KEY:
32
+ # raise ValueError("GROQ_API_KEY not found in environment variables. Please add it to your .env file.")
33
+
34
+ # # Create LLM model
35
+ # try:
36
+ # llm = ChatGroq(model_name="llama3-8b-8192", groq_api_key=GROQ_API_KEY)
37
+ # logger.info("Successfully initialized Groq client")
38
+ # except Exception as e:
39
+ # logger.error(f"Failed to initialize Groq client: {str(e)}")
40
+ # raise
41
+
42
+ # class LLMInference:
43
+ # """Class for interacting with LLM via Groq API using LangChain"""
44
+
45
+ # def __init__(self, model_id: str = "llama3-8b-8192"):
46
+ # """Initialize the LLM inference class with Groq model"""
47
+ # self.model_id = model_id
48
+ # self.llm = llm
49
+
50
+ # # Initialize prompt templates and chains
51
+ # self._init_prompt_templates()
52
+ # self._init_chains()
53
+
54
+ # logger.info(f"LLMInference initialized with model: {model_id}")
55
+
56
+ # def _init_prompt_templates(self):
57
+ # """Initialize all prompt templates"""
58
+
59
+ # # EDA insights prompt template
60
+ # self.eda_prompt_template = ChatPromptTemplate.from_messages([
61
+ # HumanMessagePromptTemplate.from_template(
62
+ # """You are a data scientist tasked with performing Exploratory Data Analysis (EDA) on a dataset.
63
+ # Based on the following dataset information, provide comprehensive EDA insights:
64
+
65
+ # Dataset Information:
66
+ # - Shape: {shape}
67
+ # - Columns and their types:
68
+ # {columns_info}
69
+
70
+ # - Missing values:
71
+ # {missing_info}
72
+
73
+ # - Basic statistics:
74
+ # {basic_stats}
75
+
76
+ # - Top correlations:
77
+ # {correlations}
78
+
79
+ # - Sample data:
80
+ # {sample_data}
81
+
82
+ # Please provide a detailed EDA analysis that includes:
83
+
84
+ # 1. Summary of the dataset (what it appears to be about, key features, etc.)
85
+ # 2. Distribution analysis of key variables
86
+ # 3. Relationship analysis between variables
87
+ # 4. Identification of patterns, outliers, or anomalies
88
+ # 5. Recommended visualizations that would be insightful
89
+ # 6. Initial hypotheses based on the data
90
+
91
+ # Your analysis should be structured, thorough, and provide actionable insights for further investigation.
92
+ # """
93
+ # )
94
+ # ])
95
+
96
+ # # Feature engineering prompt template
97
+ # self.feature_engineering_prompt_template = ChatPromptTemplate.from_messages([
98
+ # HumanMessagePromptTemplate.from_template(
99
+ # """You are a machine learning engineer specializing in feature engineering.
100
+ # Based on the following dataset information, provide recommendations for feature engineering:
101
+
102
+ # Dataset Information:
103
+ # - Shape: {shape}
104
+ # - Columns and their types:
105
+ # {columns_info}
106
+
107
+ # - Basic statistics:
108
+ # {basic_stats}
109
+
110
+ # - Top correlations:
111
+ # {correlations}
112
+
113
+ # Please provide comprehensive feature engineering recommendations that include:
114
+
115
+ # 1. Numerical feature transformations (scaling, normalization, log transforms, etc.)
116
+ # 2. Categorical feature encoding strategies
117
+ # 3. Feature interaction suggestions
118
+ # 4. Dimensionality reduction approaches if applicable
119
+ # 5. Time-based feature creation if applicable
120
+ # 6. Text processing techniques if there are text fields
121
+ # 7. Feature selection recommendations
122
+
123
+ # For each recommendation, explain why it would be beneficial and how it could improve model performance.
124
+ # Be specific to this dataset's characteristics rather than providing generic advice.
125
+ # """
126
+ # )
127
+ # ])
128
+
129
+ # # Data quality prompt template
130
+ # self.data_quality_prompt_template = ChatPromptTemplate.from_messages([
131
+ # HumanMessagePromptTemplate.from_template(
132
+ # """You are a data quality expert.
133
+ # Based on the following dataset information, provide data quality insights and recommendations:
134
+
135
+ # Dataset Information:
136
+ # - Shape: {shape}
137
+ # - Columns and their types:
138
+ # {columns_info}
139
+
140
+ # - Missing values:
141
+ # {missing_info}
142
+
143
+ # - Basic statistics:
144
+ # {basic_stats}
145
+
146
+ # Please provide a comprehensive data quality assessment that includes:
147
+
148
+ # 1. Assessment of data completeness (missing values)
149
+ # 2. Identification of potential data inconsistencies or errors
150
+ # 3. Recommendations for data cleaning and preprocessing
151
+ # 4. Advice on handling outliers
152
+ # 5. Suggestions for data validation checks
153
+ # 6. Recommendations to improve data quality
154
+
155
+ # Your assessment should be specific to this dataset and provide actionable recommendations.
156
+ # """
157
+ # )
158
+ # ])
159
+
160
+ # # QA prompt template
161
+ # self.qa_prompt_template = ChatPromptTemplate.from_messages([
162
+ # HumanMessagePromptTemplate.from_template(
163
+ # """You are a data scientist answering questions about a dataset.
164
+ # Based on the following dataset information, please answer the user's question:
165
+
166
+ # Dataset Information:
167
+ # - Shape: {shape}
168
+ # - Columns and their types:
169
+ # {columns_info}
170
+
171
+ # - Basic statistics:
172
+ # {basic_stats}
173
+
174
+ # User's question: {question}
175
+
176
+ # Please provide a clear, informative answer to the user's question based on the dataset information provided.
177
+ # """
178
+ # )
179
+ # ])
180
+
181
+ # def _init_chains(self):
182
+ # """Initialize all chains using modern RunnableSequence pattern"""
183
+
184
+ # # EDA insights chain
185
+ # self.eda_chain = self.eda_prompt_template | self.llm
186
+
187
+ # # Feature engineering chain
188
+ # self.feature_engineering_chain = self.feature_engineering_prompt_template | self.llm
189
+
190
+ # # Data quality chain
191
+ # self.data_quality_chain = self.data_quality_prompt_template | self.llm
192
+
193
+ # # QA chain
194
+ # self.qa_chain = self.qa_prompt_template | self.llm
195
+
196
+ # def _format_columns_info(self, columns: List[str], dtypes: Dict[str, str]) -> str:
197
+ # """Format columns info for prompt"""
198
+ # return "\n".join([f"- {col} ({dtypes.get(col, 'unknown')})" for col in columns])
199
+
200
+ # def _format_missing_info(self, missing_values: Dict[str, tuple]) -> str:
201
+ # """Format missing values info for prompt"""
202
+ # missing_info = "\n".join([f"- {col}: {count} missing values ({percent}%)"
203
+ # for col, (count, percent) in missing_values.items() if count > 0])
204
+
205
+ # if not missing_info:
206
+ # missing_info = "No missing values detected."
207
+
208
+ # return missing_info
209
+
210
+ # def _execute_chain(
211
+ # self,
212
+ # chain: RunnableSequence,
213
+ # input_data: Dict[str, Any],
214
+ # operation_name: str
215
+ # ) -> str:
216
+ # """
217
+ # Execute a chain with tracking and error handling
218
+
219
+ # Args:
220
+ # chain: The LangChain chain to execute
221
+ # input_data: The input data for the chain
222
+ # operation_name: Name of the operation for logging
223
+
224
+ # Returns:
225
+ # str: The generated text
226
+ # """
227
+ # try:
228
+ # start_time = time.time()
229
+ # with get_openai_callback() as cb:
230
+ # result = chain.invoke(input_data).content
231
+ # elapsed_time = time.time() - start_time
232
+
233
+ # logger.info(f"{operation_name} generated in {elapsed_time:.2f} seconds")
234
+ # logger.info(f"Tokens used: {cb.total_tokens}, "
235
+ # f"Prompt tokens: {cb.prompt_tokens}, "
236
+ # f"Completion tokens: {cb.completion_tokens}")
237
+
238
+ # return result
239
+ # except Exception as e:
240
+ # error_msg = f"Error executing {operation_name.lower()}: {str(e)}"
241
+ # logger.error(error_msg)
242
+ # return error_msg
243
+
244
+ # def generate_eda_insights(self, dataset_info: Dict[str, Any]) -> str:
245
+ # """
246
+ # Generate EDA insights based on dataset information using LangChain
247
+
248
+ # Args:
249
+ # dataset_info: Dictionary containing dataset analysis
250
+
251
+ # Returns:
252
+ # str: Detailed EDA insights and recommendations
253
+ # """
254
+ # logger.info("Generating EDA insights")
255
+
256
+ # # Format the input data
257
+ # columns_info = self._format_columns_info(
258
+ # dataset_info.get("columns", []),
259
+ # dataset_info.get("dtypes", {})
260
+ # )
261
+
262
+ # missing_info = self._format_missing_info(
263
+ # dataset_info.get("missing_values", {})
264
+ # )
265
+
266
+ # # Prepare input for the chain
267
+ # input_data = {
268
+ # "shape": dataset_info.get("shape", "N/A"),
269
+ # "columns_info": columns_info,
270
+ # "missing_info": missing_info,
271
+ # "basic_stats": dataset_info.get("basic_stats", ""),
272
+ # "correlations": dataset_info.get("correlations", ""),
273
+ # "sample_data": dataset_info.get("sample_data", "N/A")
274
+ # }
275
+
276
+ # return self._execute_chain(self.eda_chain, input_data, "EDA insights")
277
+
278
+ # def generate_feature_engineering_recommendations(self, dataset_info: Dict[str, Any]) -> str:
279
+ # """
280
+ # Generate feature engineering recommendations based on dataset information using LangChain
281
+
282
+ # Args:
283
+ # dataset_info: Dictionary containing dataset analysis
284
+
285
+ # Returns:
286
+ # str: Feature engineering recommendations
287
+ # """
288
+ # logger.info("Generating feature engineering recommendations")
289
+
290
+ # # Format the input data
291
+ # columns_info = self._format_columns_info(
292
+ # dataset_info.get("columns", []),
293
+ # dataset_info.get("dtypes", {})
294
+ # )
295
+
296
+ # # Prepare input for the chain
297
+ # input_data = {
298
+ # "shape": dataset_info.get("shape", "N/A"),
299
+ # "columns_info": columns_info,
300
+ # "basic_stats": dataset_info.get("basic_stats", ""),
301
+ # "correlations": dataset_info.get("correlations", "")
302
+ # }
303
+
304
+ # return self._execute_chain(
305
+ # self.feature_engineering_chain,
306
+ # input_data,
307
+ # "Feature engineering recommendations"
308
+ # )
309
+
310
+ # def generate_data_quality_insights(self, dataset_info: Dict[str, Any]) -> str:
311
+ # """
312
+ # Generate data quality insights based on dataset information using LangChain
313
+
314
+ # Args:
315
+ # dataset_info: Dictionary containing dataset analysis
316
+
317
+ # Returns:
318
+ # str: Data quality insights and improvement recommendations
319
+ # """
320
+ # logger.info("Generating data quality insights")
321
+
322
+ # # Format the input data
323
+ # columns_info = self._format_columns_info(
324
+ # dataset_info.get("columns", []),
325
+ # dataset_info.get("dtypes", {})
326
+ # )
327
+
328
+ # missing_info = self._format_missing_info(
329
+ # dataset_info.get("missing_values", {})
330
+ # )
331
+
332
+ # # Prepare input for the chain
333
+ # input_data = {
334
+ # "shape": dataset_info.get("shape", "N/A"),
335
+ # "columns_info": columns_info,
336
+ # "missing_info": missing_info,
337
+ # "basic_stats": dataset_info.get("basic_stats", "")
338
+ # }
339
+
340
+ # return self._execute_chain(
341
+ # self.data_quality_chain,
342
+ # input_data,
343
+ # "Data quality insights"
344
+ # )
345
+
346
+ # def answer_dataset_question(self, question: str, dataset_info: Dict[str, Any]) -> str:
347
+ # """
348
+ # Answer a specific question about the dataset using LangChain
349
+
350
+ # Args:
351
+ # question: User's question about the dataset
352
+ # dataset_info: Dictionary containing dataset analysis
353
+
354
+ # Returns:
355
+ # str: Answer to the user's question
356
+ # """
357
+ # logger.info(f"Answering dataset question: {question[:50]}...")
358
+
359
+ # # Format the input data
360
+ # columns_info = self._format_columns_info(
361
+ # dataset_info.get("columns", []),
362
+ # dataset_info.get("dtypes", {})
363
+ # )
364
+
365
+ # # Prepare input for the chain
366
+ # input_data = {
367
+ # "shape": dataset_info.get("shape", "N/A"),
368
+ # "columns_info": columns_info,
369
+ # "basic_stats": dataset_info.get("basic_stats", ""),
370
+ # "question": question
371
+ # }
372
+
373
+ # return self._execute_chain(
374
+ # self.qa_chain,
375
+ # input_data,
376
+ # "Answer"
377
+ # )
378
+
379
+
380
+
381
+
382
+
383
+
384
+
385
+
386
+
387
+
388
+
389
+
390
+
391
+
392
+
393
+
394
+
395
+
396
+
397
+
398
  """
399
  LLM Inference Module
400
 
 
414
  from langchain_groq import ChatGroq
415
  from langchain_core.messages import HumanMessage
416
  from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
417
+ from langchain_community.callbacks.manager import get_openai_callback
418
  from langchain_core.runnables import RunnableSequence
419
 
420
  # Configure logging
 
772
  input_data,
773
  "Answer"
774
  )
775
+
776
def answer_with_memory(self, question: str, dataset_info: Dict[str, Any], memory) -> str:
    """
    Answer a question with conversation memory to maintain context

    The dataset summary, the rendered conversation history and the new
    question are combined into a single prompt, piped into ``self.llm``
    and run through ``self._execute_chain``. The question/response pair
    is then written back to ``memory``.

    Both memory operations are best-effort: a failure to load the history
    falls back to an empty history, and a failure to save the exchange is
    logged without discarding the response that was already generated.

    Args:
        question: User's question about the dataset
        dataset_info: Dictionary containing dataset analysis; the keys read
            here are "columns", "dtypes", "shape" and "basic_stats"
        memory: ConversationBufferMemory instance to store conversation history

    Returns:
        str: Whatever ``self._execute_chain`` returns for this prompt
    """
    logger.info(f"Answering with memory: {question[:50]}...")

    # Format the input data for the dataset context
    columns_info = self._format_columns_info(
        dataset_info.get("columns", []),
        dataset_info.get("dtypes", {})
    )

    # Create a custom prompt that includes both conversation history and dataset info
    memory_prompt = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template(
            """You are a data scientist answering questions about a dataset.
The following is information about the dataset:

Dataset Information:
- Shape: {shape}
- Columns and their types:
{columns_info}

- Basic statistics:
{basic_stats}

Previous conversation:
{chat_history}

User's new question: {question}

Please provide a clear, informative answer to the user's question. Take into account the previous conversation for context. Make your answer specific to the dataset information provided."""
        )
    ])

    # Create a chain that uses both the prompt and memory
    memory_chain = memory_prompt | self.llm

    # Render the stored history as text for the prompt.
    try:
        chat_history = memory.load_memory_variables({})["chat_history"]
        if isinstance(chat_history, str):
            # The memory may hand back the history already rendered as one
            # string; iterating it would walk single characters and fail on
            # `.type`, silently dropping the real history.
            chat_history_str = chat_history.strip()
        else:
            chat_history_str = "\n".join(
                f"{msg.type}: {msg.content}" for msg in chat_history
            )
    except Exception as e:
        logger.warning(f"Error loading memory: {str(e)}. Using empty chat history.")
        chat_history_str = ""

    # First turn or failed load: say so explicitly rather than leaving the
    # "Previous conversation:" section of the prompt blank.
    if not chat_history_str:
        chat_history_str = "No previous conversation."

    input_data = {
        "shape": dataset_info.get("shape", "N/A"),
        "columns_info": columns_info,
        "basic_stats": dataset_info.get("basic_stats", ""),
        "question": question,
        "chat_history": chat_history_str
    }

    # Execute the chain and get a response
    response = self._execute_chain(
        memory_chain,
        input_data,
        "Answer with memory"
    )

    # Save the interaction to memory. Loading is best-effort, so saving is
    # too: a memory failure must not throw away an answer already produced.
    try:
        memory.save_context(
            {"input": question},
            {"output": response}
        )
    except Exception as e:
        logger.warning(f"Error saving to memory: {str(e)}. Conversation history was not updated.")

    return response
853
+