# Spaces: Runtime error
import logging
import os
import re
from datetime import datetime
from typing import List, Optional

from fastapi.responses import JSONResponse
from llama_index.core.llms import MessageRole
from pymongo.mongo_client import MongoClient

from core.chat.chatstore import ChatStore
from core.chat.engine import Engine
from core.parser import clean_text, update_response, sort_and_renumber_sources
from script.vector_db import IndexManager
from service.dto import ChatMessage
class ChatCompletionService:
    """Produce one non-streaming chat completion for a session and persist it.

    The service loads the existing vector indexes, builds a chat engine for the
    session, asks it the user's request, extracts the ``[n]`` citation numbers
    from the answer, and stores the answer (with the cited sources' metadata)
    in the chat store. When ``type_bot == "specific"`` the request/answer pair
    is also written to MongoDB.
    """

    def __init__(
        self,
        session_id: str,
        user_request: str,
        titles: Optional[List] = None,
        type_bot: str = "general",
    ):
        """Store the request parameters and create the backing clients.

        Args:
            session_id: Chat session identifier; also used as the MongoDB
                collection name when chat history is saved.
            user_request: The user's message to answer.
            titles: Forwarded to the chat engine only when ``type_bot`` is not
                ``"general"``.
            type_bot: ``"general"`` or another bot type; ``"specific"`` also
                saves history to MongoDB.
        """
        self.session_id = session_id
        self.user_request = user_request
        self.titles = titles
        self.type_bot = type_bot
        # NOTE(review): every instance opens its own MongoClient (and pool);
        # callers should close ``self.client`` when they are done.
        self.client = MongoClient(os.getenv("MONGO_URI"))
        self.engine = Engine()
        self.index_manager = IndexManager()
        self.chatstore = ChatStore()

    def generate_completion(self):
        """Answer ``user_request`` and persist the result.

        Returns:
            ``(response_text, metadata_collection, scores)`` on success, where
            ``metadata_collection`` has one metadata dict per resolved citation
            (source text attached under ``"content"``) and ``scores`` holds the
            matching retrieval scores. On any failure a ``JSONResponse`` with
            status 500 is returned instead of raising.
        """
        if not self._ping_mongo():
            return JSONResponse(status_code=500, content="Database Error: Unable to connect to MongoDB")

        try:
            # Load and retrieve chat engine with appropriate index
            index = self.index_manager.load_existing_indexes()
            chat_engine = self._get_chat_engine(index)

            # Generate chat response
            response = chat_engine.chat(self.user_request)
            sources = response.sources

            number_reference_sorted = self._extract_sorted_references(response)
            contents, metadata_collection, scores = self._process_sources(sources, number_reference_sorted)

            # Update response and renumber sources
            response = update_response(str(response))
            contents = sort_and_renumber_sources(contents)

            # Add contents to metadata.
            # NOTE(review): contents and metadata are paired by position; this
            # assumes sort_and_renumber_sources keeps citation order — confirm.
            metadata_collection = self._attach_contents_to_metadata(contents, metadata_collection)

            # Save the message to chat store
            self._store_message_in_chatstore(response, metadata_collection)
        except Exception as e:
            # logging.exception also records the traceback for diagnosis.
            logging.exception(f"An error occurred in generate text: {e}")
            return JSONResponse(
                status_code=500,
                content=f"An internal server error occurred: {e}"
            )

        try:
            if self.type_bot == "specific":
                self._save_chat_history_to_db(response, metadata_collection)
        except Exception as e:
            logging.exception(f"An error occurred while saving chat history: {e}")
            return JSONResponse(
                status_code=500,
                content=f"An internal server error occurred while saving chat history: {e}"
            )
        return str(response), metadata_collection, scores

    def _ping_mongo(self):
        """Return True if the MongoDB deployment answers a ``ping``, else False."""
        try:
            self.client.admin.command("ping")
        except Exception as e:
            logging.error(f"MongoDB connection failed: {e}")
            return False
        logging.info("Pinged your deployment. Successfully connected to MongoDB!")
        return True

    def _get_chat_engine(self, index):
        """Build the chat engine; titles and bot type are passed only for non-general bots."""
        if self.type_bot == "general":
            return self.engine.get_chat_engine(self.session_id, index)
        return self.engine.get_chat_engine(self.session_id, index, self.titles, self.type_bot)

    def _extract_sorted_references(self, response):
        """Return the distinct ``[n]`` citation numbers in ``response``, ascending.

        Numbers are returned as strings but ordered numerically, so ``[10]``
        sorts after ``[2]``; numerically equal citations (``[01]``/``[1]``)
        collapse to a single entry.
        """
        numbers = {int(match) for match in re.findall(r"\[(\d+)\]", str(response))}
        return [str(number) for number in sorted(numbers)]

    def _process_sources(self, sources, number_reference_sorted):
        """Collect text, metadata and score for each cited source node.

        Args:
            sources: ``response.sources`` from the chat engine; only the first
                entry's ``raw_output.source_nodes`` is consulted.
            number_reference_sorted: 1-based citation numbers (str or int).

        Returns:
            ``(contents, metadata_collection, scores)`` — parallel lists with one
            entry per citation that maps to an existing source node.
        """
        contents, metadata_collection, scores = [], [], []
        if not number_reference_sorted:
            logging.info("There are no references")
            return contents, metadata_collection, scores

        # NOTE(review): sources[0] may be a mapping or a model object; dict()
        # accepts both, whereas a bare ``.get`` only exists on mappings.
        raw_output = dict(sources[0]).get("raw_output") if sources else None
        if not raw_output:
            logging.warning("No sources available")
            return contents, metadata_collection, scores

        nodes = raw_output.source_nodes
        for number in number_reference_sorted:
            position = int(number) - 1  # citations are 1-based
            if not 0 <= position < len(nodes):
                logging.warning(f"Invalid reference number: {number}")
                continue
            node_with_score = nodes[position]
            contents.append(clean_text(node_with_score.node.get_text()))
            metadata_collection.append(dict(node_with_score.node.metadata))
            scores.append(node_with_score.score)
        return contents, metadata_collection, scores

    def _attach_contents_to_metadata(self, contents, metadata_collection):
        """Store each content (minus ``"source N:"`` labels) in the matching metadata dict.

        Pairs by position up to the shorter list; mutates and returns
        ``metadata_collection``.
        """
        for metadata, content in zip(metadata_collection, contents):
            metadata["content"] = re.sub(r"source \d+:", "", content)
        return metadata_collection

    def _store_message_in_chatstore(self, response, metadata_collection):
        """Replace the session's last chat-store message with the assistant answer."""
        message = ChatMessage(
            role=MessageRole.ASSISTANT,
            content=response,
            metadata=metadata_collection
        )
        self.chatstore.delete_last_message(self.session_id)
        self.chatstore.add_message(self.session_id, message)
        self.chatstore.clean_message(self.session_id)

    def _save_chat_history_to_db(self, response, metadata_collection):
        """Insert the request/answer pair into ``bot_database[session_id]``."""
        chat_history_db = [
            ChatMessage(
                # NOTE(review): the user's request is stored with role SYSTEM,
                # not USER — confirm readers of this collection expect that.
                role=MessageRole.SYSTEM,
                content=self.user_request,
                timestamp=datetime.now(),
                # NOTE(review): only reached when type_bot == "specific", so
                # this currently always evaluates to None.
                payment="free" if self.type_bot == "general" else None,
            ),
            ChatMessage(
                role=MessageRole.ASSISTANT,
                content=response,
                metadata=metadata_collection,
                timestamp=datetime.now(),
                payment="free" if self.type_bot == "general" else None,
            ),
        ]
        chat_history_json = [message.model_dump() for message in chat_history_db]
        db = self.client["bot_database"]
        collection = db[self.session_id]
        result = collection.insert_many(chat_history_json)
        logging.info(f"Data inserted with record ids {result.inserted_ids}")
def generate_completion_non_streaming(session_id, user_request, titles=None, type_bot="general"):
    """Run a single non-streaming chat completion and release its MongoDB client.

    Builds a ``ChatCompletionService`` for the request and returns whatever
    ``generate_completion`` returns. That is a ``(response, metadata, scores)``
    tuple, or a 500 ``JSONResponse`` on error.
    """
    chat_service = ChatCompletionService(session_id, user_request, titles, type_bot)
    try:
        return chat_service.generate_completion()
    finally:
        # The service opens a new MongoClient per call. The returned values
        # (strings, plain lists/dicts, or a JSONResponse) hold no live cursor,
        # so the client can be closed immediately instead of leaking its pool.
        chat_service.client.close()