"""FastAPI service for per-user retrieval-augmented question answering.

``/load_document`` splits a user's text into chunks, embeds them with a
sentence-transformers model and persists them in a per-user Chroma store
under ``./chroma_db_users/<user_id>/``. ``/query`` retrieves chunks from that
store and asks FLAN-T5 to answer using only the retrieved context.
"""

import os

# Force Hugging Face libraries to cache under /app/hf_cache (writable in Spaces).
# Set before the library imports below so the values are in place when those
# libraries read them.
HF_CACHE_DIR = "/app/hf_cache"
for _cache_var in (
    "HF_HOME",
    "TRANSFORMERS_CACHE",
    "HF_DATASETS_CACHE",
    "HF_METRICS_CACHE",
    "SENTENCE_TRANSFORMERS_CACHE",
):
    os.environ[_cache_var] = HF_CACHE_DIR

# Imports intentionally follow the environment setup above (E402 is expected).
from fastapi import FastAPI, HTTPException  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from pydantic import BaseModel  # noqa: E402
import torch  # noqa: E402
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # noqa: E402
from langchain_community.embeddings import HuggingFaceEmbeddings  # noqa: E402
from langchain_community.vectorstores import Chroma  # noqa: E402
from langchain_core.documents import Document  # noqa: E402
from langchain_text_splitters import RecursiveCharacterTextSplitter  # noqa: E402
from langchain_core.prompts import ChatPromptTemplate  # noqa: E402

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ────────────────
# Model configuration
# ────────────────
LLM_MODEL_NAME = "google/flan-t5-small"  # Encoder-decoder T5
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Root directory holding one Chroma persistence directory per user.
USER_DB_ROOT = "./chroma_db_users"

# Stateless helpers shared by every request; built once instead of per call.
TEXT_SPLITTER = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
QA_PROMPT = ChatPromptTemplate.from_template(
    """Answer the question based ONLY on the context below:

Context:
{context}

Question: {question}"""
)

# Populated by load_models() at application startup.
llm_tokenizer = None
llm_model = None
embeddings = None

# In-memory cache of opened vector stores, keyed by user_id.
user_vectorstores = {}


class LoadDocRequest(BaseModel):
    user_id: str
    text: str


class QueryRequest(BaseModel):
    user_id: str
    query: str


def _user_persist_dir(user_id: str) -> str:
    """Return the Chroma persistence directory for *user_id*.

    ``user_id`` comes straight from the request body and is interpolated into
    a filesystem path. Values that could escape ``USER_DB_ROOT`` or alias
    another user's directory are rejected: path separators, ``.``/``..``,
    NUL bytes and the empty string.

    Raises:
        HTTPException: 400 if ``user_id`` is not a safe single path segment.
    """
    if (
        user_id in ("", ".", "..")
        or "/" in user_id
        or "\\" in user_id
        or "\x00" in user_id
    ):
        raise HTTPException(status_code=400, detail="Invalid user_id")
    return f"{USER_DB_ROOT}/{user_id}/"


@app.on_event("startup")
async def load_models():
    """Load the T5 tokenizer/model and the embedding model into module globals."""
    global llm_tokenizer, llm_model, embeddings

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load T5 tokenizer + model
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME).to(device)

    # Load embeddings
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)


# The endpoints below are plain ``def`` on purpose: embedding, Chroma I/O and
# T5 generation are blocking, and FastAPI runs sync path operations in a
# worker threadpool instead of on the event loop.


@app.post("/load_document")
def load_document(data: LoadDocRequest):
    """Chunk, embed and persist ``data.text`` into the caller's vector store.

    Returns:
        ``{"message": ...}`` reporting the number of chunks stored.

    Raises:
        HTTPException: 400 for an unsafe ``user_id`` or text that yields no chunks.
    """
    user_id = data.user_id
    persist_dir = _user_persist_dir(user_id)

    base_document = Document(page_content=data.text, metadata={"source": "upload"})
    chunks = TEXT_SPLITTER.split_documents([base_document])
    if not chunks:
        # Nothing to index; don't hand Chroma an empty batch.
        raise HTTPException(status_code=400, detail="Document text is empty")

    os.makedirs(persist_dir, exist_ok=True)
    vectorstore = Chroma.from_documents(
        chunks, embedding=embeddings, persist_directory=persist_dir
    )
    user_vectorstores[user_id] = vectorstore

    return {"message": f"Loaded {len(chunks)} chunks for user {user_id}"}


@app.post("/query")
def query(data: QueryRequest):
    """Answer ``data.query`` using chunks retrieved from the caller's store.

    The store is taken from the in-memory cache or, failing that, reopened
    from its persistence directory.

    Returns:
        ``{"response": ...}`` with the generated answer, or
        ``{"error": ...}`` when the user has no stored documents.

    Raises:
        HTTPException: 400 for an unsafe ``user_id``.
    """
    user_id = data.user_id
    query_text = data.query
    persist_dir = _user_persist_dir(user_id)

    # Ensure vectorstore exists for this user
    vectorstore = user_vectorstores.get(user_id)
    if vectorstore is None:
        if not os.path.isdir(persist_dir):
            return {"error": f"No vectorstore found for user {user_id}"}
        vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        user_vectorstores[user_id] = vectorstore

    docs = vectorstore.as_retriever().invoke(query_text)
    context = "\n\n".join(doc.page_content for doc in docs)
    prompt = QA_PROMPT.format(context=context, question=query_text)

    # Pass the whole encoding so generate() also receives the attention mask.
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    output_ids = llm_model.generate(**inputs, max_new_tokens=200)
    response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Remove the prompt prefix if the model echoes it
    return {"response": response.replace(prompt, "").strip()}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)