File size: 7,610 Bytes
5c89b18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from fastapi import FastAPI, Query
from fastapi.responses import StreamingResponse
from sentence_transformers import SentenceTransformer
import chromadb
import os
import json
import asyncio
from pdf_to_vector_store import pdf_to_vector_store, safe_print
from llm_interface import LLMInterface
from typing import List, Tuple
from dotenv import load_dotenv

# Module-level setup: everything below runs once, at import time, and a
# failure in any step prevents the application from starting.

# Load variables from a .env file (python-dotenv) so the os.getenv lookups
# further down can see them.
load_dotenv()

# FastAPI application object; the /chat route below is registered on it.
app = FastAPI()

# List of PDF files with absolute paths
# NOTE(review): these are machine-specific absolute paths (D: drive); on any
# other machine none will exist and the ValueError below is raised.
pdf_files = [
    "D:/CSE299/PythonProject1/data/Diet.pdf",
    "D:/CSE299/PythonProject1/data/Goal.pdf",
    "D:/CSE299/PythonProject1/data/Precausions.pdf",
    "D:/CSE299/PythonProject1/data/Scientific.pdf",
    "D:/CSE299/PythonProject1/data/Templates.pdf",
    "D:/CSE299/PythonProject1/data/Workout.pdf"
]

# ChromaDB persistent storage path
# Relative path, so it is resolved against the process's working directory.
CHROMA_DB_PATH = "chroma_db"

# Initialize backend components
# Keep only the PDFs that actually exist on disk; missing ones are dropped
# silently, and startup fails only when none of them exist.
valid_pdf_files = [pdf for pdf in pdf_files if os.path.exists(pdf)]
if not valid_pdf_files:
    raise ValueError("No valid PDF files found. Check file paths.")

# Get API keys using os.getenv
# A missing or empty value is treated as "not set" and aborts startup.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set.")

gemini_api_key = os.getenv("GEMINI_API_KEY")
if not gemini_api_key:
    raise ValueError("GEMINI_API_KEY environment variable not set.")

# The Nebius and Fireworks keys are not read from their own environment
# variables: both reuse HF_TOKEN. Because hf_token was already validated
# above, the two checks below can never fire, and their messages name
# variables (NEBIUS_API_KEY / FIREWORKS_API_KEY) that this module never reads.
# NOTE(review): confirm that reusing the Hugging Face token is intentional.
nebius_api_key = hf_token
if not nebius_api_key:
    raise ValueError("NEBIUS_API_KEY environment variable not set.")

fireworks_api_key = hf_token
if not fireworks_api_key:
    raise ValueError("FIREWORKS_API_KEY environment variable not set.")

# Shared LLM wrapper; stream_responses reads llm.models and
# llm.prompt_template from it.
llm = LLMInterface(hf_api_token=hf_token, gemini_api_key=gemini_api_key, nebius_api_key=nebius_api_key,
                   fireworks_api_key=fireworks_api_key)
# Embedding model used to encode incoming queries in retrieve_top_k.
embedding_model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
# Build (or open) the vector store from the available PDFs. Only `collection`
# is used in this module; all_chunks / all_metadata are presumably the chunk
# texts and their metadata -- verify against pdf_to_vector_store.
all_chunks, all_metadata, collection = pdf_to_vector_store(valid_pdf_files, CHROMA_DB_PATH, llm.tokenizer)


def retrieve_top_k(
    query: str,
    collection: chromadb.Collection,
    model: SentenceTransformer,
    k: int = 3,
) -> List[Tuple[str, float, dict]]:
    """
    Fetch the k stored chunks closest to a query from a ChromaDB collection.

    The query is embedded with ``model`` and the collection is asked for its
    ``k`` nearest entries. Each reported distance is turned into a score as
    ``1 - distance``, floored at 0.0.

    Returns a list of ``(document, score, metadata)`` tuples in the order the
    collection returned them.
    """
    # encode() is given a one-element batch; take its only row as a plain list.
    query_vector = model.encode([query])[0].tolist()
    hits = collection.query(query_embeddings=[query_vector], n_results=k)

    # A single query was sent, so only the first result group is relevant.
    scored_chunks = []
    for text, distance, metadata in zip(hits['documents'][0], hits['distances'][0], hits['metadatas'][0]):
        score = max(0.0, 1 - distance)
        scored_chunks.append((text, score, metadata))
    return scored_chunks


def is_context_relevant(top_k_results: List[Tuple[str, float, dict]], threshold: float = 0.5) -> bool:
    """
    Decide whether retrieved chunks are good enough to be used as context.

    The results count as relevant when the best score among the
    ``(chunk, score, metadata)`` tuples reaches ``threshold``. An empty (or
    missing) result list is never relevant.
    """
    if top_k_results:
        best_score = max(similarity for _text, similarity, _meta in top_k_results)
        return best_score >= threshold
    return False


# System prompts for the chat-completions style providers. Any model type that
# is neither "gemini" nor "fireworks" is treated as Nebius.
_FIREWORKS_SYSTEM_PROMPT = "You are a helpful assistant specializing in workout, diet, and gym recommendations. For questions unrelated to these topics, respond saying you dont have permission from Aurnobb to say those ."
_NEBIUS_SYSTEM_PROMPT = "You are a helpful assistant specializing in workout, diet, and gym recommendations. For questions unrelated to these topics, respond conversationally without directly answering the question, and steer the conversation back to fitness, diet, or gym topics."

# Prefix used when logging each received chunk, keyed by model type.
_CHUNK_LOG_LABELS = {"gemini": "Gemini", "fireworks": "Fireworks"}


def _sse_event(payload: dict) -> str:
    """Serialize ``payload`` as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _open_model_stream(model: dict, prompt: str):
    """Start a streaming generation request on the client stored in ``model``."""
    if model["type"] == "gemini":
        return model["client"].generate_content(
            prompt,
            generation_config={"max_output_tokens": 2048, "temperature": 0.5},
            safety_settings={
                "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
                "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
                "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"
            },
            stream=True
        )

    system_prompt = _FIREWORKS_SYSTEM_PROMPT if model["type"] == "fireworks" else _NEBIUS_SYSTEM_PROMPT
    return model["client"].chat.completions.create(
        model=model["name"],
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        max_tokens=2048,
        temperature=0.5,
        stream=True
    )


def _chunk_text(model_type: str, chunk) -> str:
    """Return the text carried by one stream chunk, or '' when it has none."""
    if model_type == "gemini":
        return chunk.text or ""
    # Chat-completions streams may deliver chunks with an empty `choices`
    # list; skip those instead of failing on choices[0].
    if not chunk.choices:
        return ""
    return chunk.choices[0].delta.content or ""


async def stream_responses(query: str, target_model: str):
    """
    Stream a RAG-backed answer from one model as Server-Sent Events.

    Args:
        query: The user's question.
        target_model: Name of the model to use; must match the ``name`` of an
            entry in ``llm.models``.

    Yields:
        SSE frames (``data: <json>`` followed by a blank line). Each
        ``response_chunk`` event carries the text accumulated so far, not just
        the newest piece. A failure is reported as a ``response_chunk`` whose
        text is the error message. The stream always ends with a ``done``
        event.
    """
    # Find the target model
    model = next((m for m in llm.models if m["name"] == target_model), None)
    if not model:
        yield _sse_event({'event': 'response_chunk', 'model': 'system', 'chunk': 'Invalid model specified.'})
        yield _sse_event({'event': 'done'})
        return

    try:
        # Retrieval and prompt building live inside the try block so that a
        # failure here is reported to the client and still followed by 'done',
        # rather than aborting the response mid-stream.
        top_k_results = retrieve_top_k(query, collection, embedding_model, k=3)
        use_context = is_context_relevant(top_k_results, threshold=0.5)
        # The retrieved text is only passed to the model when it is relevant.
        context = "\n".join(chunk for chunk, _, _ in top_k_results) if use_context else ""
        prompt = llm.prompt_template.format(context=context, question=query)

        safe_print(f"Streaming response for model: {model['name']}")
        model_type = model["type"]
        log_label = _CHUNK_LOG_LABELS.get(model_type, "Nebius")
        stream = _open_model_stream(model, prompt)

        # NOTE(review): the provider stream is consumed with a plain `for`, so
        # any blocking the client does while waiting for the next chunk happens
        # on the event-loop thread; the short sleep below is the only point
        # where control is handed back to the loop.
        full_response = ""
        for chunk in stream:
            text = _chunk_text(model_type, chunk)
            if not text:
                continue
            safe_print(f"{log_label} chunk: {text}")
            full_response += text
            yield _sse_event({'event': 'response_chunk', 'model': model['name'], 'chunk': full_response})
            await asyncio.sleep(0.01)
    except Exception as e:
        # Top-level boundary for the stream: report the failure to the client
        # instead of dropping the connection.
        error_msg = f"Error in {model['name']}: {str(e)}"
        safe_print(error_msg)
        yield _sse_event({'event': 'response_chunk', 'model': model['name'], 'chunk': error_msg})

    yield _sse_event({'event': 'done'})


@app.get("/chat")
async def chat(query: str = Query(...), model: str = Query(...)):
    """
    SSE endpoint: answer ``query`` with the model named ``model``.

    Both values are required query-string parameters. The response body is
    the event stream produced by ``stream_responses``.
    """
    event_stream = stream_responses(query, model)
    return StreamingResponse(event_stream, media_type="text/event-stream")