from rag_pipeline import process_file, answer_query
from pydantic import BaseModel
from typing import List, Dict, Optional
from datetime import datetime
class ChatMessage(BaseModel):
    """One question/answer exchange stored in the per-file chat history."""

    question: str        # the user's question as received by /query
    answer: str          # the generated answer returned to the client
    timestamp: datetime  # when the exchange was recorded (naive local time)
class QueryRequest(BaseModel):
    """Request schema for /query.

    NOTE(review): this model is currently unused — the /query endpoint
    parses a raw ``Body(...)`` dict instead of binding to this schema;
    consider wiring it up so FastAPI validates requests automatically.
    """

    file_id: str   # id returned by /upload
    question: str  # the user's question
    # NOTE(review): `page` is not read by the /query handler — confirm whether
    # it is still needed by clients before removing.
    page: int
    explainLike5: bool = False  # camelCase kept to match the JSON wire format
import os
import re
import shutil
import uuid

from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
load_dotenv()

# Root directory holding one persisted Chroma vector store per uploaded file.
CHROMA_DIR = "./chroma_db"
# Sentence-transformer model used to embed both documents and queries.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
app = FastAPI()
# Directory where raw uploaded files are written to disk.
BASE_DIR = "files"
# NOTE(review): allow_origins=["*"] accepts requests from any origin —
# tighten before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# In-memory retriever cache keyed by file_id (populated by /upload;
# lost on process restart — /query re-opens the persisted store instead).
file_store = {}
# Per-file conversation history, keyed by file_id (also process-local).
chat_memory: Dict[str, List[ChatMessage]] = {}
@app.get("/test")
async def test():
    """Health-check endpoint confirming the API is reachable."""
    payload = {"message": "hello world!"}
    return payload
@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Persist an uploaded file, index it for retrieval, and return its id.

    The file is saved under BASE_DIR as "<file_id>_<sanitized name>" and
    handed to process_file() to build the vector index; the resulting
    retriever is cached in file_store.

    Returns:
        {"message": ..., "file_id": <uuid string used by /query and /delete>}
    """
    content = await file.read()
    file_id = str(uuid.uuid4())
    # basename() strips any client-supplied directory components (path-traversal
    # defense, e.g. "../../etc/passwd"); filename may be None per the multipart
    # spec, so fall back to a placeholder. Spaces are replaced to keep
    # downstream paths simple.
    safe_filename = os.path.basename(file.filename or "upload").replace(" ", "_")
    full_filename = f"{file_id}_{safe_filename}"
    os.makedirs(BASE_DIR, exist_ok=True)
    save_path = os.path.join(BASE_DIR, full_filename)
    with open(save_path, "wb") as f:
        f.write(content)
    # Build and cache the retriever/vector index for this document.
    retriever = process_file(content, safe_filename, file_id)
    file_store[file_id] = retriever
    return {"message": "File processed", "file_id": file_id}
@app.post("/query")
async def query_endpoint(request = Body(...)):
    """Answer a question about a previously uploaded file.

    Expects a JSON body with:
      file_id      -- id returned by /upload (required)
      question     -- the user's question (required)
      selectedText -- optional text the user highlighted in the document
      explainLike5 -- optional flag forwarded to answer_query

    Returns {"answer": ..., "context_used": ...}. Raises 422 on missing
    fields, 404 when no persisted vector store exists for file_id, and
    500 on any other failure.
    """
    file_id = request.get("file_id")
    question = request.get("question")
    selected_text = request.get("selectedText")
    explain_like_5 = request.get("explainLike5", False)
    if not file_id or not question:
        raise HTTPException(status_code=422, detail="Missing file_id or question")
    retriever_path = f"{CHROMA_DIR}/{file_id}"
    if not os.path.exists(retriever_path):
        raise HTTPException(status_code=404, detail="Vectorstore for this file_id not found.")
    try:
        # Re-open the persisted per-file vector store (the in-memory
        # file_store cache is not consulted here).
        vectorstore = Chroma(
            embedding_function=embedding_model,
            persist_directory=retriever_path
        )
        # MMR search: fetch 8 candidates, keep 4, balancing relevance vs.
        # diversity via lambda_mult.
        retriever = vectorstore.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": 4,
                "fetch_k": 8,
                "lambda_mult": 0.7,
            }
        )
        # First, retrieve context around the selected text if any was sent.
        contexts = []
        if selected_text:
            selected_results = retriever.invoke(selected_text)
            contexts.extend([doc.page_content for doc in selected_results])
        # Then retrieve context for the question itself.
        question_results = retriever.invoke(question)
        contexts.extend([doc.page_content for doc in question_results])
        # Remove duplicates while preserving order (dict keys keep insertion order).
        contexts = list(dict.fromkeys(contexts))
        # Format the context with clear section separation.
        formatted_context = ""
        if selected_text:
            formatted_context += f"Selected Text Context:\n{selected_text}\n\n"
        formatted_context += "Related Document Contexts:\n" + "\n---\n".join(
            re.sub(r"\s+", " ", context.strip())  # collapse whitespace runs
            for context in contexts
        )
        # Prepend recent chat history (last 3 exchanges) so follow-up
        # questions carry conversational context.
        if file_id in chat_memory and chat_memory[file_id]:
            chat_history = "\n\nPrevious Conversation:\n"
            for msg in chat_memory[file_id][-3:]:  # Include last 3 exchanges
                chat_history += f"Q: {msg.question}\nA: {msg.answer}\n\n"
            formatted_context = chat_history + formatted_context
        # Generate the answer using the assembled context.
        answer = answer_query(question, formatted_context, explain_like_5)
        # Record the exchange so later queries can see it as history.
        if file_id not in chat_memory:
            chat_memory[file_id] = []
        chat_memory[file_id].append(ChatMessage(
            question=question,
            answer=answer,
            timestamp=datetime.now()
        ))
        return {
            "answer": answer,
            "context_used": formatted_context  # Optionally return context for debugging
        }
    except Exception as e:
        # NOTE(review): str(e) is sent to the client and may leak internals;
        # consider logging server-side and returning a generic message.
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
@app.delete("/delete/{file_id}")
async def delete_file(file_id: str):
    """Delete everything associated with file_id.

    Removes (1) the persisted Chroma vector store, (2) the uploaded file(s)
    on disk, and (3) the in-memory chat history. Returns a summary of what
    was deleted; raises 404 if nothing existed for this id and 500 on
    unexpected errors.
    """
    try:
        # 1. Delete the Chroma vector store. Record success in a flag:
        # re-checking os.path.exists() after rmtree would always be False,
        # which previously made the response report embeddings_deleted=False
        # on success and caused spurious 404s.
        chroma_path = os.path.join(CHROMA_DIR, file_id)
        embeddings_deleted = False
        if os.path.exists(chroma_path):
            try:
                # Load the collection so Chroma can drop it cleanly, then
                # remove the persistence directory itself.
                vectorstore = Chroma(
                    embedding_function=embedding_model,
                    persist_directory=chroma_path
                )
                vectorstore.delete_collection()
                shutil.rmtree(chroma_path)
                embeddings_deleted = True
            except Exception as e:
                # Best-effort: keep going so the file and chat memory are
                # still cleaned up even if the vector store removal fails.
                print(f"Error deleting Chroma DB: {str(e)}")
        # 2. Delete the uploaded file(s); /upload names them "<file_id>_<name>".
        matching_files = []
        if os.path.isdir(BASE_DIR):  # BASE_DIR may not exist before any upload
            for filename in os.listdir(BASE_DIR):
                if filename.startswith(file_id):
                    file_path = os.path.join(BASE_DIR, filename)
                    try:
                        os.remove(file_path)
                        matching_files.append(filename)
                    except Exception as e:
                        # Report the actual filename (previously printed
                        # a literal "(unknown)").
                        print(f"Error deleting file {filename}: {str(e)}")
        # 3. Clear chat memory for this file.
        if file_id in chat_memory:
            del chat_memory[file_id]
        # Nothing existed under this id at all -> 404.
        if not matching_files and not embeddings_deleted:
            raise HTTPException(
                status_code=404,
                detail=f"No files found for file_id: {file_id}"
            )
        return {
            "message": "File, embeddings, and chat history deleted successfully",
            "deleted_files": matching_files,
            "embeddings_deleted": embeddings_deleted
        }
    except HTTPException:
        # Let the 404 above pass through unwrapped.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error during deletion: {str(e)}"
        )