import os
from time import time
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from dotenv import load_dotenv
import google.generativeai as genai

from rag_store import ingest_documents, search_knowledge

# -----------------------
# Setup
# -----------------------
# Load GEMINI_API_KEY (and any other secrets) from a local .env file into the
# process environment before configuring the Gemini client.
load_dotenv()
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

app = FastAPI(
    title="Gemini RAG FastAPI",
    docs_url="/docs",
    redoc_url="/redoc"
)

# -----------------------
# CORS
# -----------------------
# NOTE(review): origins/methods/headers are wide open ("*") — acceptable for
# local development, but should be narrowed before any public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# -----------------------
# Frontend
# -----------------------
# Static assets are served from ./frontend; index.html itself is returned by
# the "/" route below.
app.mount("/frontend", StaticFiles(directory="frontend"), name="frontend")

# -----------------------
# Cache (protect quota)
# -----------------------
CACHE_TTL = 300  # seconds an answer stays fresh
answer_cache = {}  # normalized prompt -> (timestamp, response dict)

# -----------------------
# Models
# -----------------------
class PromptRequest(BaseModel):
    """Request body for the /ask endpoint: a single free-text question."""
    prompt: str

# -----------------------
# Routes
# -----------------------

@app.get("/", response_class=HTMLResponse)
def serve_ui():
    """Return the frontend single-page shell as raw HTML."""
    with open("frontend/index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return html

# -----------------------
# Upload
# -----------------------
@app.post("/upload")
async def upload(files: list[UploadFile] = File(...)):
    """Index the uploaded files into the RAG store.

    Returns a summary message on success, or a 400 JSON error if
    ingestion fails for any reason.
    """
    try:
        chunks = ingest_documents(files)
    except Exception as exc:
        return JSONResponse(status_code=400, content={"error": str(exc)})
    return {"message": f"Indexed {chunks} chunks from {len(files)} file(s)."}

# -----------------------
# Ask
# -----------------------
def _evict_expired(now: float) -> None:
    """Drop cache entries older than CACHE_TTL so answer_cache cannot grow without bound."""
    stale = [key for key, (ts, _) in answer_cache.items() if now - ts >= CACHE_TTL]
    for key in stale:
        del answer_cache[key]

@app.post("/ask")
async def ask(data: PromptRequest):
    """Answer ``data.prompt`` strictly from the indexed documents via Gemini.

    Returns a dict with keys ``answer``, ``confidence`` and ``citations``.
    Responses (including "I don't know" ones) are cached per normalized
    prompt for CACHE_TTL seconds to protect the LLM quota.
    """
    prompt_key = data.prompt.strip().lower()
    now = time()

    # 🔁 Cache: prune expired entries, then reuse a fresh answer if present.
    _evict_expired(now)
    entry = answer_cache.get(prompt_key)
    if entry is not None:
        ts, cached = entry
        if now - ts < CACHE_TTL:
            return cached

    results = search_knowledge(data.prompt)
    if not results:
        # Nothing retrieved: answer honestly instead of letting the LLM guess.
        response = {
            "answer": "I don't know based on the provided documents.",
            "confidence": 0.0,
            "citations": []
        }
        answer_cache[prompt_key] = (now, response)
        return response

    context = "\n\n".join(r["text"] for r in results)

    prompt = f"""
Answer strictly using the context below.
If not found, say "I don't know".

Context:
{context}

Question:
{data.prompt}
"""

    try:
        model = genai.GenerativeModel("gemini-2.5-flash")
        llm_response = model.generate_content(prompt)

        response = {
            "answer": llm_response.text,
            # Crude retrieval-based score: 5+ supporting chunks => 1.0.
            "confidence": round(min(1.0, len(results) / 5), 2),
            "citations": [
                {"source": r["metadata"]["source"], "page": r["metadata"]["page"]}
                for r in results
            ]
        }

        answer_cache[prompt_key] = (now, response)
        return response

    except Exception:
        # NOTE(review): this maps *every* failure (network error, safety
        # block, malformed response) to a 429 quota message — consider
        # distinguishing genuine quota errors from other exceptions.
        return JSONResponse(
            status_code=429,
            content={"error": "LLM quota exceeded. Please wait and retry."}
        )

# -----------------------
# Summarize
# -----------------------
@app.post("/summarize")
async def summarize():
    """Convenience endpoint: delegate to /ask with a fixed summarization prompt."""
    request = PromptRequest(
        prompt="Summarize the uploaded documents in 5 concise bullet points."
    )
    return await ask(request)