File size: 16,963 Bytes
c59d808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
# LLM Service - RAG pipeline using ConversationalRetrievalChain
"""LLM service module.

Builds the RAG pipeline for the application: provider-selected LLM +
vector-store retriever + conversation memory, combined in a
ConversationalRetrievalChain. A module-level singleton ``llm_service``
is created at import time (bottom of file).
"""
from typing import List, Dict, Any, Optional

# Local imports (project configuration and the shared vector store service)
from backend.config.settings import settings
from backend.config.logging_config import get_logger
from backend.services.vector_store import vector_store_service

# Module-level logger, named after this service
logger = get_logger("llm_service")

class LLMService:
    """LLM service using ConversationalRetrievalChain for RAG pipeline.

    At construction time this wires together four pieces:

    * an LLM chosen from ``settings.get_llm_config()`` — OpenAI, Google,
      Ollama, or HuggingFace (API or local pipeline),
    * a retriever obtained from the global ``vector_store_service``,
    * a ``ConversationBufferMemory`` holding the chat history,
    * a ``ConversationalRetrievalChain`` combining all of the above.

    Construction raises (ImportError or provider error) if the selected
    provider's package is missing or setup fails.
    """

    # Words stripped from queries before vector search — they carry no
    # recipe-matching signal. frozenset gives O(1) membership tests.
    _STOP_WORDS = frozenset([
        'i', 'want', 'a', 'an', 'the', 'for', 'with', 'can', 'you',
        'give', 'me', 'please', 'help',
    ])

    def __init__(self):
        logger.info("🤖 Initializing LLM Service...")

        try:
            self.llm = self._setup_llm()
            self.retriever = self._setup_retriever()
            self.memory = self._setup_memory()
            self.qa_chain = self._setup_qa_chain()

            logger.info("🚀 LLM Service initialized successfully")

        except Exception as e:
            logger.error(f"❌ LLM Service initialization failed: {str(e)}", exc_info=True)
            raise

    def _setup_llm(self) -> Any:
        """Setup LLM based on configuration with conditional imports.

        Returns a LangChain LLM/chat-model instance for the configured
        provider. Imports are done lazily inside each branch so that only
        the selected provider's package needs to be installed.

        Raises:
            ImportError: if the selected provider's package is unavailable.
        """
        llm_config = settings.get_llm_config()
        provider = llm_config["provider"]

        logger.info(f"🔧 Setting up LLM provider: {provider}")

        if provider == "openai":
            try:
                from langchain_openai import ChatOpenAI
                logger.info("✅ OpenAI LLM imported successfully")

                # Handle special cases for temperature restrictions
                temperature = llm_config["temperature"]
                model = llm_config["model"]
                max_tokens = llm_config.get("max_tokens", 1000)

                # GPT-5-nano has temperature restrictions (defaults to 1.0)
                if "gpt-5-nano" in model.lower():
                    temperature = 1.0
                    logger.info(f"🔧 Using temperature=1.0 for {model} (model restriction)")

                # Log token configuration
                logger.info(f"🎯 OpenAI config - Model: {model}, Output tokens: {max_tokens}, Temperature: {temperature}")

                return ChatOpenAI(
                    api_key=llm_config["api_key"],
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens  # This limits OUTPUT tokens only
                )
            except ImportError as e:
                logger.error(f"❌ OpenAI LLM not available: {e}")
                raise ImportError("OpenAI provider selected but langchain_openai not installed")

        elif provider == "google":
            try:
                from langchain_google_genai import ChatGoogleGenerativeAI
                logger.info("✅ Google LLM imported successfully")

                max_output_tokens = llm_config.get("max_tokens", 1000)
                model = llm_config["model"]
                temperature = llm_config["temperature"]

                # Log token configuration
                logger.info(f"🎯 Google config - Model: {model}, Output tokens: {max_output_tokens}, Temperature: {temperature}")

                return ChatGoogleGenerativeAI(
                    google_api_key=llm_config["api_key"],
                    model=model,
                    temperature=temperature,
                    max_output_tokens=max_output_tokens  # This limits OUTPUT tokens only
                )
            except ImportError as e:
                logger.error(f"❌ Google LLM not available: {e}")
                raise ImportError("Google provider selected but langchain_google_genai not installed")

        elif provider == "ollama":
            try:
                from langchain_community.llms import Ollama
                logger.info("✅ Ollama LLM imported successfully")
                return Ollama(
                    base_url=llm_config["base_url"],
                    model=llm_config["model"],
                    temperature=llm_config["temperature"]
                )
            except ImportError as e:
                logger.error(f"❌ Ollama LLM not available: {e}")
                raise ImportError("Ollama provider selected but langchain_community not installed")

        elif provider == "huggingface":
            try:
                # Check if we should use API or local pipeline
                use_api = llm_config.get("use_api", False)

                if use_api:
                    # Use HuggingFace Inference API with better error handling
                    try:
                        from langchain_huggingface import HuggingFaceEndpoint
                        logger.info("✅ Using HuggingFace API (no local download)")

                        return HuggingFaceEndpoint(
                            repo_id=llm_config["model"],
                            huggingfacehub_api_token=llm_config["api_token"],
                            temperature=0.7,  # HuggingFace API doesn't support dynamic temperature from config
                            max_new_tokens=200,
                            repetition_penalty=1.1,
                            top_p=0.9
                        )
                    except Exception as api_error:
                        logger.warning(f"⚠️ HuggingFace API failed: {api_error}")
                        logger.info("🔄 Falling back to HuggingFace Hub API...")

                        # Fallback to HuggingFaceHub (older but more reliable)
                        try:
                            from langchain_community.llms import HuggingFaceHub

                            return HuggingFaceHub(
                                repo_id=llm_config["model"],
                                huggingfacehub_api_token=llm_config["api_token"],
                                model_kwargs={
                                    "temperature": 0.7,  # HuggingFace Hub API has limited temperature control
                                    "max_new_tokens": 200,
                                    "repetition_penalty": 1.1,
                                    "top_p": 0.9,
                                    "do_sample": True
                                }
                            )
                        except Exception as hub_error:
                            logger.error(f"❌ HuggingFace Hub also failed: {hub_error}")
                            raise ImportError(f"Both HuggingFace API methods failed: {api_error}, {hub_error}")
                else:
                    # Use local pipeline (downloads model weights on first use)
                    from langchain_huggingface import HuggingFacePipeline
                    from transformers import pipeline

                    logger.info("✅ Using HuggingFace local pipeline")

                    # Create HuggingFace pipeline - avoid device_map for CPU-only setups
                    pipeline_kwargs = {
                        "task": "text-generation",
                        "model": llm_config["model"],
                        "max_length": 512,  # Increase max length
                        "do_sample": True,  # Enable sampling for better responses
                        "temperature": 0.7,  # Local pipeline uses default 0.7 for stability
                        "pad_token_id": 50256,  # Set pad token to avoid warnings
                        "eos_token_id": 50256,  # Set end of sequence token
                    }

                    # Only add device_map if using GPU
                    if llm_config.get("use_gpu", False):
                        pipeline_kwargs["device_map"] = "auto"
                    else:
                        # For CPU, pin the pipeline to the CPU device explicitly
                        pipeline_kwargs["device"] = "cpu"

                    hf_pipeline = pipeline(**pipeline_kwargs)

                    return HuggingFacePipeline(
                        pipeline=hf_pipeline,
                        model_kwargs={
                            "temperature": 0.7,  # Local pipeline temperature (limited configurability)
                            "max_new_tokens": 150,  # Reduced for efficiency
                            "do_sample": True,
                            "top_p": 0.9,
                            "repetition_penalty": 1.1,
                            "early_stopping": True,
                            "num_beams": 4  # Better quality for instruction following
                        }
                    )
            except ImportError as e:
                logger.error(f"❌ HuggingFace LLM not available: {e}")
                raise ImportError("HuggingFace provider selected but required packages not installed")

        else:
            # Unknown provider: best-effort fallback to a default OpenAI client
            logger.warning(f"⚠️ Unknown LLM provider '{provider}', falling back to OpenAI")
            try:
                from langchain_openai import ChatOpenAI
                return ChatOpenAI()
            except ImportError:
                logger.error("❌ No valid LLM provider available")
                raise ImportError("No valid LLM provider available")

    def _setup_retriever(self) -> Any:
        """Setup retriever from the shared vector store service."""
        return vector_store_service.get_retriever()

    def _setup_memory(self) -> Any:
        """Setup conversation memory (chat history buffer).

        Raises:
            ImportError: if langchain's memory module is unavailable.
        """
        try:
            from langchain.memory import ConversationBufferMemory
            # memory_key must match the chain's expected input variable
            return ConversationBufferMemory(memory_key='chat_history', return_messages=True)
        except ImportError as e:
            logger.error(f"❌ ConversationBufferMemory not available: {e}")
            raise ImportError("langchain memory not available")

    def _setup_qa_chain(self) -> Any:
        """Setup the ConversationalRetrievalChain over llm + retriever + memory.

        Raises:
            ImportError: if langchain's chains module is unavailable.
        """
        try:
            from langchain.chains import ConversationalRetrievalChain
            return ConversationalRetrievalChain.from_llm(
                llm=self.llm,
                retriever=self.retriever,
                memory=self.memory,
                verbose=settings.LANGCHAIN_DEBUG  # Chain-level debug tracing toggle
            )
        except ImportError as e:
            logger.error(f"❌ ConversationalRetrievalChain not available: {e}")
            raise ImportError("langchain chains not available")

    def _preprocess_query(self, question: str) -> str:
        """Preprocess a user query to improve vector search accuracy.

        Lowercases the text, strips punctuation, removes filler stop words,
        and collapses whitespace. Punctuation is removed BEFORE stop-word
        filtering so that tokens like "please?" or "me," are recognized as
        stop words (filtering first would let them slip through).
        """
        import re

        # Lowercase for consistency, then strip punctuation (keep word chars/spaces)
        cleaned = re.sub(r'[^\w\s]', '', question.lower())

        # Drop stop words; split() also normalizes runs of whitespace
        kept_words = [word for word in cleaned.split() if word not in self._STOP_WORDS]
        processed = ' '.join(kept_words)

        logger.debug(f"🔧 Query preprocessing: '{question}' → '{processed}'")
        return processed

    def ask_question(self, user_question: str) -> str:
        """Ask a question using the conversational retrieval chain.

        Retrieves documents for both the raw and preprocessed query (for
        token logging), then runs the chain with an enhanced prompt and
        returns the generated answer. Errors are caught and returned as a
        user-facing apology string rather than raised.
        """
        logger.info(f"❓ Processing: '{user_question[:60]}...'")

        try:
            # Preprocess query for better matching
            processed_query = self._preprocess_query(user_question)

            # Get context for token tracking (the chain does its own retrieval;
            # this pass is only to measure how much context is in play)
            document_retriever = getattr(self.qa_chain, 'retriever', None)
            retrieved_context = ""
            if document_retriever:
                # Use both queries for comprehensive results
                original_docs = document_retriever.invoke(user_question)
                processed_docs = document_retriever.invoke(processed_query)

                # Deduplicate documents by exact page content
                seen_content = set()
                unique_documents = []
                for document in original_docs + processed_docs:
                    if document.page_content not in seen_content:
                        unique_documents.append(document)
                        seen_content.add(document.page_content)

                retrieved_context = "\n".join([doc.page_content for doc in unique_documents[:8]])
                logger.debug(f"📄 Retrieved {len(unique_documents)} unique documents")

            # Enhanced question for natural responses
            enhanced_question = f"""Based on the available recipe information, please answer this cooking question: "{user_question}"

Respond directly and naturally as if you're sharing your own culinary knowledge. If there's a specific recipe that matches the request, share the complete recipe with ingredients and step-by-step instructions in a friendly, conversational way."""

            # Use .invoke() (modern LangChain runnable API) instead of the
            # deprecated chain __call__ — consistent with the retriever calls above
            result = self.qa_chain.invoke({"question": enhanced_question})
            generated_answer = result["answer"]

            self._log_token_usage(user_question, retrieved_context, generated_answer)

            logger.info(f"✅ Response generated ({len(generated_answer)} chars)")
            return generated_answer

        except Exception as error:
            logger.error(f"❌ Error in ask_question: {str(error)}")
            return f"Sorry, I encountered an error: {str(error)}"

    def _count_tokens(self, text: str) -> int:
        """Count tokens in text (rough ~4 chars/token estimate for debugging)."""
        return len(text) // 4 if text else 0

    def _log_token_usage(self, question: str, context: str, response: str) -> Dict[str, int]:
        """Log estimated token usage for monitoring.

        Returns a dict with ``input_tokens``, ``output_tokens`` and
        ``total_tokens`` (estimates; callers may ignore the return value).
        """
        question_tokens = self._count_tokens(question)
        context_tokens = self._count_tokens(context)
        response_tokens = self._count_tokens(response)
        total_input_tokens = question_tokens + context_tokens

        logger.info(f"📊 Token Usage - Input:{total_input_tokens} (Q:{question_tokens}+C:{context_tokens}), Output:{response_tokens}")

        if context_tokens > 3000:
            logger.warning(f"⚠️ Large context detected: {context_tokens} tokens")

        return {
            "input_tokens": total_input_tokens,
            "output_tokens": response_tokens,
            "total_tokens": total_input_tokens + response_tokens
        }

    def clear_memory(self) -> bool:
        """Clear conversation memory. Returns True on success, False otherwise."""
        try:
            if hasattr(self.memory, 'clear'):
                self.memory.clear()
                logger.info("✅ Memory cleared")
                return True
        except Exception as e:
            logger.warning(f"⚠️ Could not clear memory: {e}")
        return False

    def simple_chat_completion(self, user_message: str) -> str:
        """Simple chat completion without RAG - direct LLM response.

        Builds a cooking-expert prompt, invokes the LLM directly, normalizes
        the response to a string, and truncates long answers to roughly two
        sentences / 300 chars. Errors are returned as an apology string.
        """
        logger.info(f"💭 Simple chat: '{user_message[:50]}...'")

        try:
            llm_prompt = f"As a knowledgeable cooking expert, share your insights about {user_message}. Provide helpful culinary advice and recommendations:\n\n"

            llm_response = self.llm.invoke(llm_prompt) if hasattr(self.llm, 'invoke') else self.llm(llm_prompt)

            # Extract content based on response type (chat message vs raw string)
            if hasattr(llm_response, 'content'):
                generated_answer = llm_response.content
            elif isinstance(llm_response, str):
                # Some completion models echo the prompt; strip it if present
                generated_answer = llm_response.replace(llm_prompt, "").strip() if llm_prompt in llm_response else llm_response
            else:
                generated_answer = str(llm_response)

            # Validate and clean response; fall back to a friendly default
            generated_answer = generated_answer.strip()
            if not generated_answer or len(generated_answer) < 10:
                generated_answer = "I'd be happy to help with recipes! Ask me about specific ingredients or dishes."

            # Limit response length to ~2 sentences (hard cap 300 chars)
            if len(generated_answer) > 300:
                answer_sentences = generated_answer.split('. ')
                generated_answer = '. '.join(answer_sentences[:2]) + '.' if len(answer_sentences) > 1 else generated_answer[:300]

            logger.info(f"✅ Response generated ({len(generated_answer)} chars)")
            return generated_answer

        except Exception as error:
            logger.error(f"❌ Simple chat completion error: {str(error)}")
            return f"Sorry, I encountered an error: {str(error)}"

# Create global LLM service instance.
# NOTE: this runs at import time — importing this module triggers full
# provider setup (LLMService.__init__) and will raise if configuration or
# provider packages are missing.
llm_service = LLMService()