# rag-reader / main.py
# viraj
# RAG app upload
# f09e297
from rag_pipeline import process_file, answer_query
from pydantic import BaseModel
class QueryRequest(BaseModel):
    """Pydantic schema describing a question asked about an uploaded file.

    NOTE(review): the /query handler visible in this file declares its body
    as a raw ``Body(...)`` and never references this model; it also reads a
    ``selectedText`` key that is not declared here. Confirm whether this
    model is used elsewhere before relying on or changing it.
    """

    # Identifier of the file the question refers to (presumably the
    # ``file_id`` value returned by /upload — confirm).
    file_id: str
    # The question text.
    question: str
    # Required integer; presumably the page the user is viewing — TODO confirm.
    page: int
    # Optional flag, off by default.
    explainLike5: bool = False
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from fastapi import Body
import uuid
import os
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
import re
# Load environment configuration first, before the model and app are built
# (presumably from a local .env file via python-dotenv — confirm).
load_dotenv()

# Parent directory of the per-file vector stores; /query looks for
# "<CHROMA_DIR>/<file_id>".
CHROMA_DIR = "./chroma_db"

# Embedding model handed to Chroma as its embedding_function when /query
# reopens a store.
# NOTE(review): presumably this must match the model process_file used when
# indexing — confirm in rag_pipeline.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

app = FastAPI()

# Directory where /upload writes the raw uploaded files.
BASE_DIR = "files"

# CORS: every origin, method and header is allowed.
# NOTE(review): this is wide open; consider restricting allow_origins before
# exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-process map of file_id -> value returned by process_file, filled by
# /upload. The /query handler in this file does not read it; it reopens the
# vector store from disk instead.
file_store = {}
@app.get("/test")
async def test():
    """Smoke-test endpoint: respond with a fixed greeting payload."""
    greeting = {"message": "hello world!"}
    return greeting
def _safe_upload_name(filename):
    """Reduce a client-supplied filename to a single, space-free path component.

    Args:
        filename: The name sent by the client; may be ``None``.

    Returns:
        The final path component of ``filename`` with spaces replaced by
        underscores, or an empty string when no name was supplied.
    """
    # The name is untrusted: treat both slash styles as separators so any
    # directory part ("../", "C:\\dir\\") is dropped before the name is joined
    # into a path on disk.
    leaf = os.path.basename((filename or "").replace("\\", "/"))
    return leaf.replace(" ", "_")


@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Store an uploaded file on disk and index it for later queries.

    The raw bytes are written to ``BASE_DIR`` under a name prefixed with a
    freshly generated UUID, then handed to ``process_file``; the value it
    returns is kept in ``file_store`` under the same id.

    Args:
        file: The multipart upload.

    Returns:
        A dict with a confirmation ``message`` and the generated ``file_id``
        that clients pass to /query.
    """
    content = await file.read()
    file_id = str(uuid.uuid4())

    # Only the sanitized leaf name is used, so the file always lands directly
    # inside BASE_DIR regardless of what the client sent.
    safe_filename = _safe_upload_name(file.filename)
    save_path = os.path.join(BASE_DIR, f"{file_id}_{safe_filename}")

    os.makedirs(BASE_DIR, exist_ok=True)
    with open(save_path, "wb") as f:
        f.write(content)

    retriever = process_file(content, safe_filename, file_id)
    file_store[file_id] = retriever
    return {"message": "File processed", "file_id": file_id}
def _is_canonical_uuid(value):
    """Return True when ``value`` is a UUID string in canonical form.

    /upload issues ids with ``str(uuid.uuid4())``; accepting only that exact
    shape guarantees the id contains no path separators or ".." components.
    """
    try:
        return str(uuid.UUID(value)) == value
    except (ValueError, AttributeError, TypeError):
        # Not a string, or not parseable as a UUID.
        return False


@app.post("/query")
async def query_endpoint(request = Body(...)):
    """Answer a question about a previously uploaded file.

    Expects a JSON object with ``file_id`` and ``question``, plus optional
    ``selectedText`` (used as the retrieval query when present) and
    ``explainLike5`` (forwarded to ``answer_query``).

    Args:
        request: The parsed JSON request body.

    Returns:
        ``{"answer": ...}`` on success, or ``{"error": ...}`` when the body is
        malformed, required fields are missing, the id is not a valid file id,
        or no vector store exists for it.
    """
    # The body is declared as raw JSON, so it may be a list/str/number;
    # only an object has the fields read below.
    if not isinstance(request, dict):
        return {"error": "Missing file_id or question"}

    file_id = request.get("file_id")
    question = request.get("question")
    selected_text = request.get("selectedText")
    explain_like_5 = request.get("explainLike5", False)

    if not file_id or not question:
        return {"error": "Missing file_id or question"}

    # file_id is untrusted and becomes part of a filesystem path; reject
    # anything that is not the exact id format /upload hands out.
    if not _is_canonical_uuid(file_id):
        return {"error": "Invalid file_id"}

    retriever_path = os.path.join(CHROMA_DIR, file_id)

    # Load retriever from disk
    if not os.path.exists(retriever_path):
        return {"error": "Vectorstore for this file_id not found."}

    vectorstore = Chroma(
        embedding_function=embedding_model,
        persist_directory=retriever_path
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})

    # Prefer the user's selection as the search text; fall back to the question.
    retrieved_docs = retriever.invoke(selected_text or question)
    retrieved_context = "\n\n".join(
        re.sub(r"\s+", " ", doc.page_content.strip()) for doc in retrieved_docs
    )

    # Mention a selection only when there is one, so the model is never told
    # that the user selected the literal text "None".
    sections = []
    if selected_text:
        sections.append(f"User selected this:\n\"{selected_text}\"")
    sections.append(f"Related parts from the document:\n{retrieved_context}")
    combined_context = "\n\n".join(sections)

    print("Combined context", combined_context)
    answer = answer_query(question, combined_context, explain_like_5)
    return {"answer": answer}