import asyncio
import json
import logging
import re
from typing import List, Dict

import faiss
import httpx
import numpy as np
import pandas as pd
import torch
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.websockets import WebSocket
from transformers import pipeline

from project.bot.models import MessagePair
from project.config import settings
|
class SearchBot:
    """Streaming RAG chatbot: retrieves FAISS context for a query, streams an
    OpenAI answer over a websocket, then enriches named entities in the final
    message with Google Places data.
    """

    def __init__(self, memory=None):
        """Create a bot with an optional pre-existing chat history.

        :param memory: prior chat messages as ``{"role", "content"}`` dicts;
            when omitted each instance gets its own fresh list.
        """
        # BUG FIX: the former class-level `chat_history = []` was a shared
        # mutable attribute shadowed by this instance attribute — any code
        # touching SearchBot.chat_history would leak state across instances.
        if memory is None:
            memory = []
        self.chat_history = memory
| |
|
| | @staticmethod |
| | def _cls_pooling(model_output): |
| | return model_output.last_hidden_state[:, 0] |
| |
|
| | @staticmethod |
| | async def enrich_information_from_google(search_word: str) -> str: |
| | url = "https://places.googleapis.com/v1/places:searchText" |
| | headers = { |
| | "Content-Type": "application/json", |
| | "X-Goog-Api-Key": settings.GOOGLE_PLACES_API_KEY, |
| | "X-Goog-FieldMask": "places.shortFormattedAddress,places.websiteUri,places.internationalPhoneNumber," |
| | "places.googleMapsUri,places.photos" |
| | } |
| | data = { |
| | "textQuery": f"{search_word} in Javea", |
| | "languageCode": "nl", |
| | "maxResultCount": 1, |
| |
|
| | } |
| | async with httpx.AsyncClient() as client: |
| | response = await client.post(url, headers=headers, content=json.dumps(data)) |
| | place_response = response.json() |
| | place_response = place_response['places'][0] |
| | photo_name = place_response.get('photos') |
| | photo_uri = None |
| | if photo_name: |
| | async with httpx.AsyncClient() as client: |
| | response = await client.get( |
| | f'https://places.googleapis.com/v1/{photo_name[0]["name"]}/media?maxWidthPx=350&key={settings.GOOGLE_PLACES_API_KEY}') |
| | photo_response = response.json() |
| | photo_uri = photo_response.get('photoUri') |
| | google_maps_uri = place_response.get('googleMapsUri') |
| | phone_number = place_response.get('internationalPhoneNumber') |
| | formatted_address = place_response.get('shortFormattedAddress') |
| | website_uri = place_response.get('websiteUri') |
| | if not google_maps_uri: |
| | return search_word |
| | enriched_word = f'<a class="extraDataLink" href="{google_maps_uri}" target="_blank">{search_word}</a><div class="tooltip-elem">' |
| | if photo_uri: |
| | enriched_word += f'<img src="{photo_uri}" alt="Image" class="tooltip-img">' |
| | if formatted_address: |
| | enriched_word += f'<p><a href="{google_maps_uri}" target="_blank">{formatted_address}</a></p>' |
| | if website_uri: |
| | enriched_word += f'<p><a href="{website_uri}">Google Maps URI</a></p>' |
| | if phone_number: |
| | phone_str = re.sub(r' ', '', phone_number) |
| | enriched_word += f'<p><a href="tel:{phone_str}">Phone number</a></p>' |
| | enriched_word += f"</div>" |
| | return enriched_word |
| |
|
| | async def analyze_full_response(self) -> str: |
| | assistant_message = self.chat_history.pop()['content'] |
| | nlp = pipeline("ner", model=settings.NLP_MODEL, tokenizer=settings.NLP_TOKENIZER, aggregation_strategy="simple") |
| | ner_result = nlp(assistant_message) |
| | analyzed_assistant_message = assistant_message |
| | for entity in ner_result: |
| | if entity['entity_group'] in ("LOC", "ORG", "MISC") and entity['word'] != "Javea": |
| | enriched_information = await self.enrich_information_from_google(entity['word']) |
| | analyzed_assistant_message = analyzed_assistant_message.replace(entity['word'], enriched_information, 1) |
| | return "ENRICHED:" + analyzed_assistant_message |
| |
|
| | async def _convert_to_embeddings(self, text_list): |
| | encoded_input = settings.INFO_TOKENIZER( |
| | text_list, padding=True, truncation=True, return_tensors="pt" |
| | ) |
| | encoded_input = {k: v.to(settings.device) for k, v in encoded_input.items()} |
| | model_output = settings.INFO_MODEL(**encoded_input) |
| | return self._cls_pooling(model_output).cpu().detach().numpy().astype('float32') |
| |
|
| | @staticmethod |
| | async def _get_context_data(user_query: list[float]) -> list[dict]: |
| | radius = 5 |
| | _, distances, indices = settings.FAISS_INDEX.range_search(user_query, radius) |
| | indices_distances_df = pd.DataFrame({'index': indices, 'distance': distances}) |
| | filtered_data_df = settings.products_dataset.iloc[indices].copy() |
| | filtered_data_df.loc[:, 'distance'] = indices_distances_df['distance'].values |
| | sorted_data_df: pd.DataFrame = filtered_data_df.sort_values(by='distance').reset_index(drop=True) |
| | sorted_data_df = sorted_data_df.drop('distance', axis=1) |
| | data = sorted_data_df.head(3).to_dict(orient='records') |
| | cleaned_data = [] |
| | for chunk in data: |
| | if "Comments:" in chunk['chunks']: |
| | cleaned_data.append(chunk) |
| | return cleaned_data |
| |
|
| | @staticmethod |
| | async def create_context_str(context: List[Dict]) -> str: |
| | context_str = '' |
| | for i, chunk in enumerate(context): |
| | context_str += f'{i + 1}) {chunk["chunks"]}' |
| | return context_str |
| |
|
| | async def _rag(self, context: List[Dict], query: str, session: AsyncSession, country: str): |
| | if context: |
| | context_str = await self.create_context_str(context) |
| | assistant_message = {"role": 'assistant', "content": context_str} |
| | self.chat_history.append(assistant_message) |
| | content = settings.PROMPT |
| | else: |
| | content = settings.EMPTY_PROMPT |
| | user_message = {"role": 'user', "content": query} |
| | self.chat_history.append(user_message) |
| | messages = [ |
| | { |
| | 'role': 'system', |
| | 'content': content |
| | }, |
| | ] |
| | messages = messages + self.chat_history |
| |
|
| | stream = await settings.OPENAI_CLIENT.chat.completions.create( |
| | messages=messages, |
| | temperature=0.1, |
| | n=1, |
| | model="gpt-3.5-turbo", |
| | stream=True |
| | ) |
| | response = '' |
| | async for chunk in stream: |
| | if chunk.choices[0].delta.content is not None: |
| | chunk_content = chunk.choices[0].delta.content |
| | response += chunk_content |
| | yield response |
| | await asyncio.sleep(0.02) |
| | assistant_message = {"role": 'assistant', "content": response} |
| | self.chat_history.append(assistant_message) |
| | try: |
| | session.add(MessagePair(user_message=query, bot_response=response, country=country)) |
| | except Exception as e: |
| | print(e) |
| |
|
| | async def ask_and_send(self, data: Dict, websocket: WebSocket, session: AsyncSession): |
| | query = data['query'] |
| | country = data['country'] |
| | transformed_query = await self._convert_to_embeddings(query) |
| | context = await self._get_context_data(transformed_query) |
| | try: |
| | async for chunk in self._rag(context, query, session, country): |
| | await websocket.send_text(chunk) |
| | analyzing = await self.analyze_full_response() |
| | await websocket.send_text(analyzing) |
| | except Exception: |
| | await self.emergency_db_saving(session) |
| |
|
    @staticmethod
    async def emergency_db_saving(session: AsyncSession):
        # Failure path used when streaming breaks mid-response: commit whatever
        # was already staged on the session so it isn't lost, then release the
        # session's connection.
        await session.commit()
        await session.close()
| |
|