mahmoudalrefaey committed
Commit c1138eb · verified · 1 Parent(s): 40db345

Upload 2 files

Files changed (2)
  1. modules/llm_manager.py +283 -0
  2. modules/rag_pipeline.py +273 -0
modules/llm_manager.py ADDED
@@ -0,0 +1,283 @@
"""
LLM Manager Module
Handles local language models using transformers and HuggingFace
"""

import logging
from typing import Dict, Any

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig,
)
from langchain_community.llms import HuggingFacePipeline


class LLMManager:
    """Manages local language models for text generation"""

    def __init__(self, model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        """
        Initialize the LLM manager.

        Args:
            model_name: Name of the HuggingFace model to use
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.llm = None

        # Configure logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Per-model generation defaults
        self.model_config = {
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0": {
                "max_length": 1024,  # reduced for speed
                "temperature": 0.7,
                "top_p": 0.95,
                "do_sample": True,
                "pad_token_id": 0,
                "eos_token_id": 2,
            },
            "microsoft/DialoGPT-medium": {
                "max_length": 512,  # reduced for speed
                "temperature": 0.7,
                "top_p": 0.9,
                "do_sample": True,
                "pad_token_id": 50256,
                "eos_token_id": 50256,
            },
            "microsoft/phi-2": {
                "max_length": 2048,
                "temperature": 0.7,
                "top_p": 0.95,
                "do_sample": True,
                "pad_token_id": 0,
                "eos_token_id": 50256,
            },
        }

        # Initialize the model
        self._initialize_model()

    def _initialize_model(self):
        """Load the tokenizer and model, then build the generation pipeline."""
        try:
            self.logger.info(f"Loading language model: {self.model_name}")

            # Check if CUDA is available
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.logger.info(f"Using device: {device}")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            # Set a padding token if the tokenizer has none
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            if device == "cuda":
                # Load with 4-bit quantization on GPU for memory efficiency
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16
                )

                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    quantization_config=bnb_config,
                    device_map="auto",
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16
                )
            else:
                # On CPU, load unquantized in float32 with reduced peak memory
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    device_map="cpu",
                    trust_remote_code=True,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )

            # Get the model configuration, falling back to TinyLlama defaults
            config = self.model_config.get(
                self.model_name,
                self.model_config["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
            )

            # Create the generation pipeline
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_length=config["max_length"],
                temperature=config["temperature"],
                top_p=config["top_p"],
                do_sample=config["do_sample"],
                pad_token_id=config["pad_token_id"],
                eos_token_id=config["eos_token_id"],
                return_full_text=False
            )

            # Create the LangChain LLM wrapper
            self.llm = HuggingFacePipeline(
                pipeline=self.pipeline,
                model_kwargs={"temperature": config["temperature"]}
            )

            self.logger.info("Language model loaded successfully")

        except Exception as e:
            self.logger.error(f"Error loading language model: {e}")
            raise

    def generate_response(self, prompt: str, max_tokens: int = 500, temperature: float = 0.7) -> str:
        """
        Generate a response using the language model.

        Args:
            prompt: Input prompt
            max_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response
        """
        try:
            if not self.llm:
                raise ValueError("Language model not initialized")

            self.logger.info(f"Generating response for prompt: {prompt[:50]}...")

            # Format the prompt for the active model
            formatted_prompt = self._format_prompt(prompt)

            # Generate the response. HuggingFacePipeline forwards per-call
            # generation settings through `pipeline_kwargs`; loose keyword
            # arguments are not passed on to the underlying pipeline.
            response = self.llm(
                formatted_prompt,
                pipeline_kwargs={
                    "max_new_tokens": max_tokens,
                    "temperature": temperature,
                    "do_sample": True,
                },
            )

            # Clean up the response
            cleaned_response = self._clean_response(response)

            self.logger.info(f"Generated response: {cleaned_response[:50]}...")
            return cleaned_response

        except Exception as e:
            self.logger.error(f"Error generating response: {e}")
            raise

    def _format_prompt(self, prompt: str) -> str:
        """
        Format the prompt according to the model's chat template.

        Args:
            prompt: Raw prompt

        Returns:
            Formatted prompt
        """
        if "TinyLlama" in self.model_name:
            # TinyLlama chat format
            return (
                "<|system|>You are a helpful AI assistant. Answer questions "
                f"based on the provided context.</s><|user|>{prompt}</s><|assistant|>"
            )
        elif "DialoGPT" in self.model_name:
            # DialoGPT format
            return f"User: {prompt}\nAssistant:"
        elif "phi" in self.model_name:
            # Phi format
            return f"Instruct: {prompt}\nOutput:"
        else:
            # Default format
            return prompt

    def _clean_response(self, response: str) -> str:
        """
        Clean up the generated response.

        Args:
            response: Raw response

        Returns:
            Cleaned response
        """
        # Remove the prompt from the response if present
        if "Instruct:" in response:
            response = response.split("Output:")[-1].strip()
        elif "User:" in response:
            response = response.split("Assistant:")[-1].strip()
        elif "<|assistant|>" in response:
            response = response.split("<|assistant|>")[-1].strip()

        # Remove any remaining special tokens
        response = response.replace("<|endoftext|>", "").replace("<|im_end|>", "").strip()

        return response

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the loaded model.

        Returns:
            Dictionary with model information
        """
        if not self.model:
            return {"status": "not_initialized"}

        try:
            # Count model parameters
            total_params = sum(p.numel() for p in self.model.parameters())
            trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

            return {
                "status": "initialized",
                "model_name": self.model_name,
                "total_parameters": f"{total_params:,}",
                "trainable_parameters": f"{trainable_params:,}",
                "device": str(next(self.model.parameters()).device),
                "dtype": str(next(self.model.parameters()).dtype)
            }

        except Exception as e:
            self.logger.error(f"Error getting model info: {e}")
            return {"status": "error", "error": str(e)}

    def change_model(self, model_name: str):
        """
        Swap in a different language model.

        Args:
            model_name: New model name
        """
        try:
            self.logger.info(f"Changing model from {self.model_name} to {model_name}")

            # Update the model name
            self.model_name = model_name

            # Release references to the existing model
            self.tokenizer = None
            self.model = None
            self.pipeline = None
            self.llm = None

            # Reinitialize with the new model
            self._initialize_model()

            self.logger.info("Model changed successfully")

        except Exception as e:
            self.logger.error(f"Error changing model: {e}")
            raise
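
A minimal usage sketch for this module (not part of the commit; assumes the dependencies above are installed and the default TinyLlama checkpoint can be downloaded):

    from modules.llm_manager import LLMManager

    manager = LLMManager()  # loads TinyLlama/TinyLlama-1.1B-Chat-v1.0 by default
    print(manager.get_model_info())  # parameter counts, device, dtype
    answer = manager.generate_response("What is retrieval-augmented generation?", max_tokens=128)
    print(answer)
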
modules/rag_pipeline.py ADDED
@@ -0,0 +1,273 @@
"""
RAG Pipeline Module
Orchestrates the retrieval-augmented generation process
"""

import logging
from typing import List, Dict, Any, Tuple

from langchain.schema import Document
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS

from .llm_manager import LLMManager


class RAGPipeline:
    """Retrieval-Augmented Generation pipeline"""

    def __init__(self, knowledge_base: FAISS, llm_manager: LLMManager):
        """
        Initialize the RAG pipeline.

        Args:
            knowledge_base: FAISS vector store
            llm_manager: LLM manager instance
        """
        self.knowledge_base = knowledge_base
        self.llm_manager = llm_manager
        self.retrieval_chain = None

        # Retrieval settings (see update_retrieval_parameters)
        self.search_type = "similarity"
        self.search_kwargs = {"k": 2}  # reduced for speed

        # Configure logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Initialize retrieval chain
        self._initialize_retrieval_chain()

    def _initialize_retrieval_chain(self):
        """Initialize the retrieval QA chain"""
        try:
            self.logger.info("Initializing retrieval QA chain")

            # Create the custom prompt template
            prompt_template = """You are a helpful AI assistant that answers questions based on the provided context.

Context: {context}

Question: {question}

Please provide a comprehensive answer based on the context above. If the context doesn't contain enough information to answer the question, say so. Be accurate and helpful.

Answer:"""

            prompt = PromptTemplate(
                template=prompt_template,
                input_variables=["context", "question"]
            )

            # Create the retrieval QA chain using the current retrieval settings
            self.retrieval_chain = RetrievalQA.from_chain_type(
                llm=self.llm_manager.llm,
                chain_type="stuff",
                retriever=self.knowledge_base.as_retriever(
                    search_type=self.search_type,
                    search_kwargs=self.search_kwargs
                ),
                chain_type_kwargs={"prompt": prompt},
                return_source_documents=True
            )

            self.logger.info("Retrieval QA chain initialized successfully")

        except Exception as e:
            self.logger.error(f"Error initializing retrieval chain: {e}")
            raise

    def get_response(self, query: str, max_tokens: int = 500, temperature: float = 0.7) -> str:
        """
        Get a response using the RAG pipeline.

        Args:
            query: User query
            max_tokens: Maximum tokens for the response
            temperature: Sampling temperature

        Returns:
            Generated response
        """
        try:
            if not self.retrieval_chain:
                raise ValueError("Retrieval chain not initialized")

            self.logger.info(f"Processing query: {query[:50]}...")

            # Retrieve relevant documents directly (k kept small for speed);
            # this bypasses the RetrievalQA chain so the prompt can be
            # constructed explicitly below.
            relevant_docs = self.knowledge_base.similarity_search(query, k=2)

            if not relevant_docs:
                return "I couldn't find any relevant information in the provided documents to answer your question."

            # Create context from the relevant documents
            context = self._create_context(relevant_docs)

            # Generate the response using the LLM
            response = self.llm_manager.generate_response(
                prompt=self._create_prompt(query, context),
                max_tokens=max_tokens,
                temperature=temperature
            )

            self.logger.info(f"Generated response: {response[:50]}...")
            return response

        except Exception as e:
            self.logger.error(f"Error in RAG pipeline: {e}")
            return f"I encountered an error while processing your question: {str(e)}"

    def _create_context(self, documents: List[Document]) -> str:
        """
        Build a context string from the retrieved documents.

        Args:
            documents: List of relevant documents

        Returns:
            Context string
        """
        context_parts = []

        for i, doc in enumerate(documents, 1):
            # Add the document source if available
            source = doc.metadata.get("source", "Unknown")
            content = doc.page_content.strip()

            context_parts.append(f"Document {i} (Source: {source}):\n{content}\n")

        return "\n".join(context_parts)

    def _create_prompt(self, query: str, context: str) -> str:
        """
        Create the prompt for the LLM.

        Args:
            query: User query
            context: Retrieved context

        Returns:
            Formatted prompt
        """
        return f"""Based on the following context, please answer the user's question. If the context doesn't contain enough information to answer the question, say so.

Context:
{context}

Question: {query}

Answer:"""

    def get_similar_documents(self, query: str, k: int = 4) -> List[Document]:
        """
        Get documents similar to a query.

        Args:
            query: Search query
            k: Number of documents to retrieve

        Returns:
            List of similar documents
        """
        try:
            return self.knowledge_base.similarity_search(query, k=k)
        except Exception as e:
            self.logger.error(f"Error retrieving similar documents: {e}")
            return []

    def get_similar_documents_with_scores(self, query: str, k: int = 4) -> List[Tuple[Document, float]]:
        """
        Get similar documents along with their similarity scores.

        Args:
            query: Search query
            k: Number of documents to retrieve

        Returns:
            List of (document, score) tuples
        """
        try:
            return self.knowledge_base.similarity_search_with_score(query, k=k)
        except Exception as e:
            self.logger.error(f"Error retrieving similar documents with scores: {e}")
            return []

    def add_documents(self, documents: List[Document]):
        """
        Add new documents to the knowledge base.

        Args:
            documents: List of documents to add
        """
        try:
            if not documents:
                return

            self.logger.info(f"Adding {len(documents)} documents to knowledge base")

            # Add documents to the vector store
            self.knowledge_base.add_documents(documents)

            # Reinitialize the retrieval chain with the updated knowledge base
            self._initialize_retrieval_chain()

            self.logger.info("Documents added successfully")

        except Exception as e:
            self.logger.error(f"Error adding documents: {e}")
            raise

    def get_pipeline_info(self) -> Dict[str, Any]:
        """
        Get information about the RAG pipeline.

        Returns:
            Dictionary with pipeline information
        """
        try:
            # Get knowledge base info
            kb_info = {}
            if self.knowledge_base:
                index = self.knowledge_base.index
                kb_info = {
                    "documents": index.ntotal if hasattr(index, 'ntotal') else "unknown",
                    "index_type": type(index).__name__
                }

            # Get LLM info
            llm_info = self.llm_manager.get_model_info()

            return {
                "status": "initialized" if self.retrieval_chain else "not_initialized",
                "knowledge_base": kb_info,
                "language_model": llm_info,
                "retrieval_chain": "initialized" if self.retrieval_chain else "not_initialized"
            }

        except Exception as e:
            self.logger.error(f"Error getting pipeline info: {e}")
            return {"status": "error", "error": str(e)}

    def update_retrieval_parameters(self, k: int = 4, search_type: str = "similarity"):
        """
        Update the retrieval parameters and rebuild the QA chain.

        Args:
            k: Number of documents to retrieve
            search_type: Type of search (similarity, mmr, etc.)
        """
        try:
            self.logger.info(f"Updating retrieval parameters: k={k}, search_type={search_type}")

            # Store the new settings; as_retriever() returns a new retriever
            # rather than mutating the vector store, so the chain must be
            # rebuilt with these values.
            self.search_type = search_type
            self.search_kwargs = {"k": k}

            # Reinitialize the chain
            self._initialize_retrieval_chain()

            self.logger.info("Retrieval parameters updated successfully")

        except Exception as e:
            self.logger.error(f"Error updating retrieval parameters: {e}")
            raise
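
A minimal end-to-end sketch showing how the two new modules compose (not part of the commit; the embedding model name and sample document are illustrative assumptions):

    from langchain.schema import Document
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    from modules.llm_manager import LLMManager
    from modules.rag_pipeline import RAGPipeline

    # Build a small FAISS knowledge base from raw documents
    docs = [Document(page_content="FAISS is a library for efficient similarity search.",
                     metadata={"source": "notes.txt"})]
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    knowledge_base = FAISS.from_documents(docs, embeddings)

    # Wire the pipeline together and ask a question
    rag = RAGPipeline(knowledge_base, LLMManager())
    print(rag.get_response("What is FAISS used for?", max_tokens=128))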