Spaces:
Build error
Build error
| # services/chat_service.py | |
import asyncio
import inspect
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple

from config.config import settings
# Module-level logger; the chat services below report search, generation and
# chat failures through it.
logger = logging.getLogger(__name__)
class ConversationManager:
    """Manage a bounded, in-memory history of chat interactions per session.

    Each session id maps to a list of interaction dicts (oldest first) with the
    keys ``timestamp``, ``user_input``, ``response`` and ``context``. Only the
    most recent ``max_history`` interactions are kept for each session.

    NOTE(review): sessions are only removed by ``clear_history``; in a
    long-running process the number of stored sessions grows without bound.
    """

    def __init__(self, max_history: int = 10):
        """Create an empty manager.

        Args:
            max_history: Maximum number of interactions retained per session
                (default 10). ``0`` means nothing is retained.

        Raises:
            ValueError: If ``max_history`` is negative.
        """
        if max_history < 0:
            raise ValueError(f"max_history must be >= 0, got {max_history}")
        self.conversations: Dict[str, List[Dict[str, Any]]] = {}
        self.max_history = max_history

    def add_interaction(
        self,
        session_id: str,
        user_input: str,
        response: str,
        context: Optional[Dict[str, Any]] = None
    ) -> None:
        """Append one interaction to a session's history, then trim it.

        Args:
            session_id: Session the interaction belongs to; created on first use.
            user_input: The user's message.
            response: The assistant's reply.
            context: Optional extra data stored alongside the interaction.
        """
        history = self.conversations.setdefault(session_id, [])
        history.append({
            'timestamp': datetime.now().isoformat(),
            'user_input': user_input,
            'response': response,
            'context': context
        })
        # Drop the oldest entries in place. A ``history[-n:]`` slice would
        # keep the whole list when n == 0, so compute the excess explicitly.
        excess = len(history) - self.max_history
        if excess > 0:
            del history[:excess]

    def get_history(self, session_id: str) -> List[Dict[str, Any]]:
        """Return the stored interactions for a session, oldest first.

        Returns the manager's own list for known sessions (mutating it changes
        the stored history) and a fresh empty list for unknown sessions.
        """
        return self.conversations.get(session_id, [])

    def clear_history(self, session_id: str) -> None:
        """Forget a session's history; unknown session ids are ignored."""
        self.conversations.pop(session_id, None)
class ChatService:
    """Coordinate retrieval, prompt building and generation for chat turns.

    Retrieval queries three sources (``data_service.search`` for products,
    ``pdf_service.search`` for document chunks, ``faq_service.search_faqs``
    for FAQs). Their top hits and the session's recent history are rendered
    into a context block that is prepended to the user's message before the
    language model is called.
    """

    def __init__(
        self,
        model_service,
        data_service,
        pdf_service,
        faq_service
    ):
        """Wire up the collaborating services.

        Args:
            model_service: Object exposing ``model`` and ``tokenizer``
                attributes (presumably a Hugging Face pair — confirm).
            data_service: Provides ``search(query, top_k)`` over products.
            pdf_service: Provides ``search(query, top_k)`` over PDF content.
            faq_service: Provides ``search_faqs(query, top_k)``.
        """
        self.model = model_service.model
        self.tokenizer = model_service.tokenizer
        self.data_service = data_service
        self.pdf_service = pdf_service
        self.faq_service = faq_service
        self.conversation_manager = ConversationManager()

    @staticmethod
    async def _search_source(name: str, search) -> List[Dict[str, Any]]:
        """Run one source search; any failure or ``None`` becomes an empty list.

        ``search`` is a zero-argument callable returning either the hits
        directly or an awaitable that resolves to them.
        """
        try:
            result = search()
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            logger.error(f"Error searching {name}: {e}")
            return []
        return result if result is not None else []

    async def search_all_sources(
        self,
        query: str,
        top_k: int = 3
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Search products, documents and FAQs concurrently.

        Each source is isolated: if one raises, only that source contributes
        an empty list, instead of discarding the other sources' results.
        Sources may be coroutine functions or plain synchronous callables.

        Returns:
            Dict with keys ``'products'``, ``'documents'`` and ``'faqs'``,
            each a list of hit dicts.
        """
        # Lambdas defer the attribute lookups and calls into _search_source's
        # try block, so even a missing method is handled per source.
        products, pdfs, faqs = await asyncio.gather(
            self._search_source(
                'products', lambda: self.data_service.search(query, top_k)
            ),
            self._search_source(
                'documents', lambda: self.pdf_service.search(query, top_k)
            ),
            self._search_source(
                'faqs', lambda: self.faq_service.search_faqs(query, top_k)
            ),
        )
        return {
            'products': products,
            'documents': pdfs,
            'faqs': faqs
        }

    def build_context(
        self,
        search_results: Dict[str, List[Dict[str, Any]]],
        chat_history: List[Dict[str, Any]]
    ) -> str:
        """Render search hits and recent chat history into one context string.

        Uses at most 2 products, 2 documents, 2 FAQs and the last 3
        interactions. Sections are separated by blank lines; labels are German
        because they are shown to the model verbatim.

        Raises:
            KeyError: If a hit dict lacks a field rendered here (e.g. a product
                without ``'Name'``).
        """
        context_parts = []
        if search_results.get('products'):
            for product in search_results['products'][:2]:
                context_parts.append(
                    f"Produkt: {product['Name']}\n"
                    f"Beschreibung: {product['Description']}\n"
                    f"Preis: {product['Price']}€\n"
                    f"Kategorie: {product['ProductCategory']}"
                )
        if search_results.get('documents'):
            for doc in search_results['documents'][:2]:
                context_parts.append(
                    f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
                    f"{doc['text']}"
                )
        if search_results.get('faqs'):
            for faq in search_results['faqs'][:2]:
                context_parts.append(
                    f"FAQ:\n"
                    f"Frage: {faq['question']}\n"
                    f"Antwort: {faq['answer']}"
                )
        if chat_history:
            history_text = "\n".join(
                f"User: {h['user_input']}\nAssistant: {h['response']}"
                for h in chat_history[-3:]
            )
            context_parts.append(f"Letzte Interaktionen:\n{history_text}")
        return "\n\n".join(context_parts)

    async def generate_response(
        self,
        prompt: str,
        max_length: int = 1000
    ) -> str:
        """Generate the assistant's reply to ``prompt`` with the language model.

        Args:
            prompt: Full prompt text (context plus user turn).
            max_length: Maximum number of *new* tokens to generate; the prompt
                does not count against this budget.

        Returns:
            The decoded reply, without the echoed prompt, stripped of
            surrounding whitespace.

        Raises:
            Exception: Any tokenizer/model error is logged and re-raised.
        """
        try:
            # NOTE(review): truncation to 4096 tokens follows the tokenizer's
            # truncation_side (right by default), which would cut the trailing
            # "User: ... Assistant:" turn off an over-long prompt — confirm.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096
            ).to(settings.DEVICE)
            # max_new_tokens (unlike max_length) excludes the prompt, so a long
            # retrieved context cannot consume the whole generation budget.
            # early_stopping was dropped: it only affects beam search.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                no_repeat_ngram_size=3
            )
            generated = outputs[0]
            # Decoder-only models return prompt + continuation; keep only the
            # continuation so the prompt is not returned or stored as history.
            model_config = getattr(self.model, "config", None)
            if not getattr(model_config, "is_encoder_decoder", False):
                generated = generated[inputs["input_ids"].shape[-1]:]
            response = self.tokenizer.decode(
                generated,
                skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise

    async def chat(
        self,
        user_input: str,
        session_id: str,
        max_length: int = 1000
    ) -> Tuple[str, Dict[str, List[Dict[str, Any]]]]:
        """Handle one chat turn: retrieve, build prompt, generate, record.

        Args:
            user_input: The user's message.
            session_id: Conversation whose history is used and extended.
            max_length: Passed to ``generate_response`` as the new-token limit.

        Returns:
            ``(response, search_results)`` where ``search_results`` is the
            dict returned by ``search_all_sources``.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            chat_history = self.conversation_manager.get_history(session_id)
            search_results = await self.search_all_sources(user_input)
            context = self.build_context(search_results, chat_history)
            prompt = (
                f"Context:\n{context}\n\n"
                f"User: {user_input}\n"
                "Assistant:"
            )
            response = await self.generate_response(prompt, max_length)
            self.conversation_manager.add_interaction(
                session_id,
                user_input,
                response,
                {'search_results': search_results}
            )
            return response, search_results
        except Exception as e:
            logger.error(f"Error in chat: {e}")
            raise