Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import datetime | |
| from mistralai import Mistral # Updated import | |
| from chroma_db_utils import get_relevant_passage | |
| # Constants | |
| MAX_RETRIES = 3 | |
| RETRY_DELAY = 1 # Initial delay in seconds | |
| MODEL_NAME = "mistral-large-latest" # Mistral's latest model | |
| REQUESTS_PER_MINUTE = 10 # Mistral's rate limit (adjust based on your tier) | |
| REQUEST_INTERVAL = 60 / REQUESTS_PER_MINUTE | |
| # Hardcoded Mistral API Key (NOT RECOMMENDED for production) | |
| MISTRAL_API_KEY = "9x8duC1VJ7n5uEwdV8nG6bmFEIqCftKn" | |
| # Initialize Mistral client | |
| mistral_client = Mistral(api_key=MISTRAL_API_KEY) # Updated client initialization | |
| def make_rag_prompt(query: str, relevant_passage: str) -> list[dict]: | |
| """ | |
| Creates a chat prompt for the Mistral RAG model. | |
| """ | |
| escaped = relevant_passage.replace("'", "").replace('"', "").replace("\n", " ") | |
| return [ | |
| { | |
| "role": "system", | |
| "content": "You are a helpful bot that answers questions using the provided context. " | |
| "If the context is irrelevant, try to answer it from your own database and knowledge. Always provide easy and understandable explanation" | |
| }, | |
| { | |
| "role": "user", | |
| "content": f"QUESTION: {query}\n\nREFERENCE TEXT: {escaped}\n\nANSWER:" | |
| } | |
| ] | |
| def generate_answer(prompt: list[dict]) -> str: | |
| """ | |
| Calls the Mistral API with retries and rate limiting. | |
| """ | |
| for attempt in range(MAX_RETRIES): | |
| start_time = datetime.datetime.now() | |
| print(f"{start_time}: Making Mistral API request (attempt {attempt + 1}/{MAX_RETRIES})...") | |
| try: | |
| # Use the new method: client.chat.complete | |
| response = mistral_client.chat.complete( | |
| model=MODEL_NAME, | |
| messages=prompt, | |
| temperature=0.3 | |
| ) | |
| end_time = datetime.datetime.now() | |
| print(f"{end_time}: Mistral API request successful. Time taken: {end_time - start_time}") | |
| return response.choices[0].message.content | |
| except Exception as e: | |
| if "rate limit" in str(e).lower() or attempt < MAX_RETRIES - 1: | |
| delay = RETRY_DELAY * (2 ** attempt) # Exponential backoff | |
| print(f"API error: {str(e)}. Retrying in {delay} seconds...") | |
| time.sleep(delay) | |
| else: | |
| raise | |
| raise Exception("Max retries exceeded for Mistral API request.") | |
| def handle_query(query: str, db, n_results: int = 5) -> str: | |
| """ | |
| Handles a user query using Mistral AI. | |
| """ | |
| relevant_passages = get_relevant_passage(query, db, n_results) | |
| relevant_passage_str = " ".join(relevant_passages) | |
| chat_prompt = make_rag_prompt(query, relevant_passage_str) | |
| return generate_answer(chat_prompt) |