| |
|
| | |
| | from typing import List, Dict, Any, Optional, Tuple |
| | from datetime import datetime |
| | import logging |
| | from config.config import settings |
| | import asyncio |
| | from io import StringIO |
| | import pandas as pd |
| |
|
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
| |
|
class ConversationManager:
    """Keep a bounded, in-memory history of exchanges per session.

    Histories are keyed by session id. Each stored entry is a dict with the
    keys ``timestamp`` (``datetime.now().isoformat()``), ``user_input``,
    ``response`` and ``context``.
    """

    def __init__(self, max_history: int = 1):
        """Create an empty manager.

        Args:
            max_history: Maximum number of interactions kept per session.
                The oldest entries are dropped first. A value of 0 or less
                keeps no history at all.
        """
        # session_id -> chronologically ordered list of interaction dicts
        self.conversations: Dict[str, List[Dict[str, Any]]] = {}
        self.max_history = max_history

    def add_interaction(
        self,
        session_id: str,
        user_input: str,
        response: str,
        context: Optional[Dict[str, Any]] = None
    ) -> None:
        """Append one exchange to ``session_id``'s history and trim it.

        Args:
            session_id: Key of the conversation; created on first use.
            user_input: The user's message.
            response: The assistant's reply.
            context: Optional extra data stored alongside the exchange.
        """
        history = self.conversations.setdefault(session_id, [])
        history.append({
            'timestamp': datetime.now().isoformat(),
            'user_input': user_input,
            'response': response,
            'context': context
        })

        # Trim in place. The explicit <= 0 branch matters: ``lst[-0:]`` would
        # return the whole list instead of an empty one.
        if self.max_history <= 0:
            history.clear()
        elif len(history) > self.max_history:
            del history[:-self.max_history]

    def get_history(self, session_id: str) -> List[Dict[str, Any]]:
        """Return the session's history, oldest first, or ``[]`` if unknown.

        A shallow copy is returned, so callers cannot accidentally truncate
        or reorder the stored history.
        """
        return list(self.conversations.get(session_id, []))

    def clear_history(self, session_id: str) -> None:
        """Forget everything stored for ``session_id``; unknown ids are ignored."""
        self.conversations.pop(session_id, None)
| |
|
class ChatService:
    """Retrieval-augmented chat: search data sources, build a prompt, generate.

    Collaborators are injected. ``model_service`` must expose ``model`` and
    ``tokenizer`` attributes. ``data_service`` and ``pdf_service`` must
    provide ``search(query, top_k)``, and ``faq_service`` must provide
    ``search_faqs(query, top_k)``. All of these return awaitables.
    """

    def __init__(
        self,
        model_service,
        data_service,
        pdf_service,
        faq_service
    ):
        self.model = model_service.model
        self.tokenizer = model_service.tokenizer
        self.data_service = data_service
        self.pdf_service = pdf_service
        self.faq_service = faq_service
        self.conversation_manager = ConversationManager()

    async def search_all_sources(
        self,
        query: str,
        top_k: int = 3
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Search products, PDF documents and FAQs concurrently.

        Args:
            query: Free-text search query.
            top_k: Number of hits requested from each source.

        Returns:
            A dict with the keys ``'products'``, ``'documents'`` and
            ``'faqs'``. Each value is a list; it is ``[]`` when that source
            returned nothing or failed. A failing source is logged and does
            not discard the other sources' results.
        """
        logger.debug("Starting searches for query %r", query)
        keys = ('products', 'documents', 'faqs')

        try:
            # The searches are independent, so await them together instead of
            # one after another.
            outcomes = await asyncio.gather(
                self.data_service.search(query, top_k),
                self.pdf_service.search(query, top_k),
                self.faq_service.search_faqs(query, top_k),
                return_exceptions=True,
            )
        except Exception as e:
            # Only reached if starting a search fails synchronously.
            logger.error(f"Error searching sources: {e}")
            return {'products': [], 'documents': [], 'faqs': []}

        results: Dict[str, List[Dict[str, Any]]] = {}
        for key, outcome in zip(keys, outcomes):
            if isinstance(outcome, Exception):
                logger.error(
                    "Error searching sources (%s): %s", key, outcome,
                    exc_info=outcome,
                )
                results[key] = []
            elif isinstance(outcome, BaseException):
                # For example, a cancelled child search. Do not disguise it
                # as "no results".
                raise outcome
            else:
                results[key] = outcome or []

        logger.debug("Search results: %s", results)
        return results

    def construct_system_prompt(self, context: str) -> str:
        """Build the system message: persona, answering rules, then ``context``.

        ``context`` is appended verbatim, followed by a blank line.
        """
        return (
            "You are a friendly bot named: Oma Erna, specializing in Bofrost products and content. Use only the context from this prompt. "
            "Return comprehensive German answers. If possible add product IDs from context. Do not make up information. The context is the truth. "
            "Use the following context (product descriptions and information) for answers:\n\n"
            f"{context}\n\n"
        )

    def construct_prompt(
        self,
        user_input: str,
        context: str,
        chat_history: List[Tuple[str, str]],
        max_history_turns: int = 1
    ) -> str:
        """Assemble the full prompt string using Llama-3-style special tokens.

        The prompt is laid out in this order:

        1. The system message (instructions plus ``context``).
        2. The last ``max_history_turns`` (user, assistant) pairs from
           ``chat_history``.
        3. ``user_input``.
        4. An open assistant header, so the model continues as the assistant.

        A ``max_history_turns`` of 0 or less includes no history.
        """
        system_message = self.construct_system_prompt(context)
        parts = [
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
            f"{system_message}<|eot_id|>"
        ]

        # Guard: chat_history[-0:] would select the *entire* history.
        recent = chat_history[-max_history_turns:] if max_history_turns > 0 else []
        for user_msg, assistant_msg in recent:
            parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>")
            parts.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>")

        parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>")
        # Deliberately left open: generation continues from here.
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
        return "".join(parts)

    def build_context(
        self,
        search_results: Dict[str, List[Dict[str, Any]]],
        chat_history: List[Dict[str, Any]],
        max_items_per_source: int = 2
    ) -> str:
        """Render search hits into the German context text used in the prompt.

        Args:
            search_results: Output of :meth:`search_all_sources`.
                - Product dicts are read via ``Name``, ``Description``,
                  ``Price`` and ``ProductCategory``.
                - Document dicts are read via ``source``, ``page`` and
                  ``text``.
                - FAQ dicts are read via ``question`` and ``answer``.
            chat_history: Raw session history. It is not rendered here;
                :meth:`construct_prompt` turns history into chat turns.
            max_items_per_source: How many hits per source to include
                (default 2). Values below 0 are treated as 0.

        Returns:
            The rendered snippets joined by blank lines, or ``""`` if there
            are none.
        """
        limit = max(max_items_per_source, 0)
        context_parts: List[str] = []

        for product in (search_results.get('products') or [])[:limit]:
            context_parts.append(
                f"Produkt: {product['Name']}\n"
                f"Beschreibung: {product['Description']}\n"
                f"Preis: {product['Price']}€\n"
                f"Kategorie: {product['ProductCategory']}"
            )

        for doc in (search_results.get('documents') or [])[:limit]:
            context_parts.append(
                f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
                f"{doc['text']}"
            )

        for faq in (search_results.get('faqs') or [])[:limit]:
            context_parts.append(
                "FAQ:\n"
                f"Frage: {faq['question']}\n"
                f"Antwort: {faq['answer']}"
            )

        if chat_history:
            logger.debug("build_context received %d history entries", len(chat_history))

        context = "\n\n".join(context_parts)
        logger.debug("Built context:\n%s", context)
        return context

    async def chat(
        self,
        user_input: str,
        session_id: Any,
        max_length: int = 8000
    ) -> Tuple[str, List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]]]:
        """Run one conversational turn end to end.

        The turn proceeds in four steps:

        1. Load the session history.
        2. Search all sources for ``user_input``.
        3. Build the context and prompt, and generate a reply.
        4. Store the exchange.

        Args:
            user_input: The user's message.
            session_id: Session key; non-strings are converted with ``str()``.
            max_length: Forwarded to :meth:`generate_response`.

        Returns:
            A tuple ``(response, history, search_results)``. ``history`` is
            the session's stored ``(user_input, response)`` pairs after this
            turn.

        Raises:
            Any exception raised along the way, after logging it.
        """
        try:
            if not isinstance(session_id, str):
                session_id = str(session_id)

            chat_history_raw = self.conversation_manager.get_history(session_id)
            chat_history = [
                (entry['user_input'], entry['response']) for entry in chat_history_raw
            ]

            search_results = await self.search_all_sources(user_input)

            context = self.build_context(search_results, chat_history_raw)
            prompt = self.construct_prompt(user_input, context, chat_history)
            # NOTE(review): generate_response is synchronous, so it blocks the
            # event loop for the whole generation. Consider offloading it to a
            # worker thread once concurrent use of the model is confirmed safe.
            response = self.generate_response(prompt, max_length)

            self.conversation_manager.add_interaction(
                session_id,
                user_input,
                response,
                {'search_results': search_results}
            )

            formatted_history = [
                (entry['user_input'], entry['response'])
                for entry in self.conversation_manager.get_history(session_id)
            ]

            return response, formatted_history, search_results

        except Exception as e:
            logger.exception("Error in chat: %s", e)
            raise

    def generate_response(
        self,
        prompt: str,
        max_length: int = 1000
    ) -> str:
        """Sample a completion for ``prompt`` and return only the new text.

        Args:
            prompt: A fully formatted prompt (see :meth:`construct_prompt`).
                It already begins with ``<|begin_of_text|>``, so the tokenizer
                is told not to add its own special tokens.
            max_length: Passed to ``model.generate`` as ``max_length``.

        Returns:
            The decoded continuation, with special tokens skipped and
            surrounding whitespace stripped.

        Raises:
            Any tokenizer or model error, after logging it.
        """
        try:
            logger.debug("Prompt:\n%s", prompt)
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096,
                # The prompt string already carries <|begin_of_text|>. Letting
                # the tokenizer prepend its BOS as well would duplicate it.
                add_special_tokens=False,
            ).to(settings.DEVICE)
            # NOTE(review): unless the tokenizer is configured for left-side
            # truncation, a prompt over 4096 tokens loses its tail (the user
            # turn and the assistant header). Confirm the context stays within
            # that budget.

            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                no_repeat_ngram_size=3,
                early_stopping=False
            )

            # outputs[0] starts with the prompt ids. Skip exactly as many ids
            # as were fed in, instead of re-tokenizing the prompt to find out.
            prompt_len = inputs["input_ids"].shape[-1]
            response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

            return response.replace("<|assistant|>", "").strip()

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise