import os
import time
import datetime
from mistralai import Mistral  # Updated import
from chroma_db_utils import get_relevant_passage

# Constants
MAX_RETRIES = 3
RETRY_DELAY = 1  # Initial delay in seconds (doubled on each retry)
MODEL_NAME = "mistral-large-latest"  # Mistral's latest model
REQUESTS_PER_MINUTE = 10  # Mistral's rate limit (adjust based on your tier)
REQUEST_INTERVAL = 60 / REQUESTS_PER_MINUTE  # Seconds between requests

# SECURITY: read the key from the MISTRAL_API_KEY environment variable.
# The hardcoded literal is kept only as a backward-compatible fallback;
# it is exposed in source control and should be revoked and removed
# before any production use.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "9x8duC1VJ7n5uEwdV8nG6bmFEIqCftKn")

# Initialize Mistral client (module-level singleton shared by all requests)
mistral_client = Mistral(api_key=MISTRAL_API_KEY)

def make_rag_prompt(query: str, relevant_passage: str) -> list[dict]:
    """
    Build the two-message chat prompt (system + user) for the Mistral RAG call.

    Single and double quotes are stripped from the passage and newlines are
    flattened to spaces before it is embedded in the user message.
    """
    sanitized = relevant_passage
    for quote in ("'", '"'):
        sanitized = sanitized.replace(quote, "")
    sanitized = sanitized.replace("\n", " ")

    system_message = {
        "role": "system",
        "content": "You are a helpful bot that answers questions using the provided context. "
                   "If the context is irrelevant, try to answer it from your own database and knowledge. Always provide easy and understandable explanation"
    }
    user_message = {
        "role": "user",
        "content": f"QUESTION: {query}\n\nREFERENCE TEXT: {sanitized}\n\nANSWER:"
    }
    return [system_message, user_message]

def generate_answer(prompt: list[dict]) -> str:
    """
    Call the Mistral chat API and return the generated answer text.

    Retries up to MAX_RETRIES times with exponential backoff
    (RETRY_DELAY * 2**attempt seconds) on any API error.

    Args:
        prompt: Chat messages (role/content dicts), e.g. from make_rag_prompt.

    Returns:
        The content of the first choice in the API response.

    Raises:
        Exception: the last API error once all retries are exhausted.
    """
    for attempt in range(MAX_RETRIES):
        start_time = datetime.datetime.now()
        print(f"{start_time}: Making Mistral API request (attempt {attempt + 1}/{MAX_RETRIES})...")

        try:
            response = mistral_client.chat.complete(
                model=MODEL_NAME,
                messages=prompt,
                temperature=0.3
            )
        except Exception as e:
            if attempt >= MAX_RETRIES - 1:
                # Final attempt: re-raise the real error instead of sleeping
                # through a useless backoff delay and masking it behind a
                # generic exception (the original did both for rate-limit
                # errors on the last attempt).
                raise
            delay = RETRY_DELAY * (2 ** attempt)  # Exponential backoff
            print(f"API error: {str(e)}. Retrying in {delay} seconds...")
            time.sleep(delay)
        else:
            end_time = datetime.datetime.now()
            print(f"{end_time}: Mistral API request successful. Time taken: {end_time - start_time}")
            return response.choices[0].message.content

    # Only reachable if MAX_RETRIES is misconfigured to be < 1.
    raise Exception("Max retries exceeded for Mistral API request.")

def handle_query(query: str, db, n_results: int = 5) -> str:
    """
    Answer a user query end-to-end: fetch the top n_results passages from
    the vector DB, join them into one context string, build the RAG prompt,
    and return the Mistral model's answer.
    """
    passages = get_relevant_passage(query, db, n_results)
    context = " ".join(passages)
    return generate_answer(make_rag_prompt(query, context))