Spaces:
Sleeping
Sleeping
File size: 2,835 Bytes
af7a842 76685ff af7a842 a7094fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import os
import time
import datetime
from mistralai import Mistral # Updated import
from chroma_db_utils import get_relevant_passage
# Constants
MAX_RETRIES = 3
RETRY_DELAY = 1  # Initial delay in seconds; doubled each retry (exponential backoff)
MODEL_NAME = "mistral-large-latest"  # Mistral's latest model
REQUESTS_PER_MINUTE = 10  # Mistral's rate limit (adjust based on your tier)
REQUEST_INTERVAL = 60 / REQUESTS_PER_MINUTE  # Minimum seconds between requests

# SECURITY: never commit API keys to source control. The key previously
# hardcoded here must be considered compromised and rotated. Read the key
# from the environment instead; set MISTRAL_API_KEY before running.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "")
if not MISTRAL_API_KEY:
    raise RuntimeError("MISTRAL_API_KEY environment variable is not set.")

# Initialize Mistral client
mistral_client = Mistral(api_key=MISTRAL_API_KEY)  # Updated client initialization
def make_rag_prompt(query: str, relevant_passage: str) -> list[dict]:
    """Build the chat-message list for a Mistral RAG completion call.

    Args:
        query: The user's question.
        relevant_passage: Context text retrieved from the vector store.

    Returns:
        A two-message prompt (system + user) in Mistral's chat format.
    """
    # Strip quotes and collapse newlines so the passage embeds cleanly
    # into the user message.
    sanitized = relevant_passage.replace("'", "").replace('"', "").replace("\n", " ")
    system_msg = {
        "role": "system",
        "content": "You are a helpful bot that answers questions using the provided context. "
        "If the context is irrelevant, try to answer it from your own database and knowledge. Always provide easy and understandable explanation"
    }
    user_msg = {
        "role": "user",
        "content": f"QUESTION: {query}\n\nREFERENCE TEXT: {sanitized}\n\nANSWER:"
    }
    return [system_msg, user_msg]
def generate_answer(prompt: list[dict]) -> str:
    """
    Call the Mistral chat API with retries and exponential backoff.

    Args:
        prompt: Chat messages, as built by make_rag_prompt.

    Returns:
        The text content of the model's first completion choice.

    Raises:
        Exception: Re-raises the last API error once all retries are spent.
    """
    for attempt in range(MAX_RETRIES):
        start_time = datetime.datetime.now()
        print(f"{start_time}: Making Mistral API request (attempt {attempt + 1}/{MAX_RETRIES})...")
        try:
            # Use the new method: client.chat.complete
            response = mistral_client.chat.complete(
                model=MODEL_NAME,
                messages=prompt,
                temperature=0.3
            )
        except Exception as e:
            # Bug fix: the previous condition ("rate limit" in message OR
            # attempts remain) retried rate-limit errors even on the final
            # attempt — sleeping pointlessly, then raising a generic
            # Exception that discarded the original error. Now we retry
            # while attempts remain and re-raise the real exception last.
            if attempt < MAX_RETRIES - 1:
                delay = RETRY_DELAY * (2 ** attempt)  # Exponential backoff: 1s, 2s, 4s...
                print(f"API error: {str(e)}. Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                raise
        else:
            end_time = datetime.datetime.now()
            print(f"{end_time}: Mistral API request successful. Time taken: {end_time - start_time}")
            return response.choices[0].message.content
    # Unreachable in practice (the last failure re-raises), kept as a guard.
    raise Exception("Max retries exceeded for Mistral API request.")
def handle_query(query: str, db, n_results: int = 5) -> str:
    """
    Answer a user query via retrieval-augmented generation with Mistral AI.

    Args:
        query: The user's question.
        db: Vector store collection searched for relevant passages.
        n_results: Number of passages to retrieve (default 5).

    Returns:
        The model-generated answer string.
    """
    passages = get_relevant_passage(query, db, n_results)
    context = " ".join(passages)
    return generate_answer(make_rag_prompt(query, context))