File size: 4,036 Bytes
f4b962e 3cc6209 109fd2d 9b3fc24 3cc6209 b44200d 109fd2d f4b962e d6795cf b44200d 109fd2d b44200d 3cc6209 109fd2d b44200d 3cc6209 b44200d 3cc6209 b44200d 3cc6209 b44200d 3cc6209 109fd2d b44200d 109fd2d 3cc6209 109fd2d b44200d 3cc6209 b44200d 109fd2d b44200d 3cc6209 109fd2d 3cc6209 d6795cf b44200d d6795cf b44200d d6795cf b44200d d6795cf b44200d d6795cf 109fd2d b44200d d6795cf 109fd2d b44200d f4b962e d6795cf 3cc6209 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 | import os
# Route every Hugging Face cache (hub, transformers, datasets, metrics,
# sentence-transformers) to /app/hf_cache, which is writable in Spaces.
# Set before the Hugging Face imports below, which may read these variables
# when they are first imported.
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HF_DATASETS_CACHE",
            "HF_METRICS_CACHE",
            "SENTENCE_TRANSFORMERS_CACHE",
        ),
        "/app/hf_cache",
    )
)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
app = FastAPI()

# Let browser clients on any origin call every route with any method/header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended before adding cookie/auth state.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# ────────────────
# Model configuration
# ────────────────
LLM_MODEL_NAME = "google/flan-t5-small"  # Encoder-decoder (seq2seq) T5 checkpoint
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Filled in by the startup hook (load_models); None until it has run.
llm_tokenizer = llm_model = embeddings = None

# Open Chroma vector stores, keyed by user_id.
user_vectorstores = dict()
class LoadDocRequest(BaseModel):
    """Request body for ``POST /load_document``: raw text to index for a user."""

    user_id: str  # whose vector store receives the text
    text: str  # raw document text; split into chunks before embedding
class QueryRequest(BaseModel):
    """Request body for ``POST /query``: a question against a user's documents."""

    user_id: str  # selects which user's vector store to search
    query: str  # question used for retrieval and inserted into the prompt
@app.on_event("startup")
async def load_models():
    """Load the T5 model, its tokenizer and the embedding model once at startup.

    Populates the module-level ``llm_tokenizer``, ``llm_model`` and
    ``embeddings`` globals that the request handlers read. The LLM is placed
    on CUDA when ``torch.cuda.is_available()`` is true, otherwise on CPU.
    """
    global llm_tokenizer, llm_model, embeddings

    target = "cuda" if torch.cuda.is_available() else "cpu"

    # Tokenizer and seq2seq model come from the same checkpoint.
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
    llm_model = seq2seq.to(torch.device(target))

    # Embedding model used for both indexing and retrieval.
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
@app.post("/load_document")
async def load_document(data: LoadDocRequest):
    """Chunk a user's text, embed it, and persist it to a per-user Chroma store.

    The store lives under ``./chroma_db_users/<user_id>`` and is cached in
    ``user_vectorstores[user_id]`` for reuse by ``/query``.

    Returns:
        ``{"message": ...}`` with the number of chunks stored, or
        ``{"error": ...}`` if ``user_id`` would resolve outside the storage
        root or the text produces no chunks.
    """
    user_id = data.user_id
    text = data.text

    # user_id comes straight from the request body. Resolve the target path and
    # require it to sit strictly inside the storage root, so ids such as
    # "../x", "/etc" or "" cannot direct Chroma to write elsewhere on disk.
    root = os.path.realpath("./chroma_db_users")
    persist_dir = os.path.realpath(os.path.join(root, user_id))
    if persist_dir == root or os.path.commonpath([root, persist_dir]) != root:
        return {"error": f"Invalid user_id {user_id!r}"}

    base_document = Document(page_content=text, metadata={"source": "upload"})
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents([base_document])

    # Empty/whitespace-only text produces no chunks: there is nothing to embed,
    # and an empty directory would later look like a valid store to /query.
    if not chunks:
        return {"error": f"No text to index for user {user_id}"}

    os.makedirs(persist_dir, exist_ok=True)
    # NOTE(review): this adds to any collection already persisted in
    # persist_dir, so re-loading the same text may store duplicate chunks —
    # confirm whether accumulation is intended.
    vectorstore = Chroma.from_documents(
        chunks, embedding=embeddings, persist_directory=persist_dir
    )
    user_vectorstores[user_id] = vectorstore
    return {"message": f"Loaded {len(chunks)} chunks for user {user_id}"}
@app.post("/query")
async def query(data: QueryRequest):
    """Answer a question from the user's indexed documents using the T5 model.

    The user's Chroma store is taken from ``user_vectorstores`` or reopened from
    ``./chroma_db_users/<user_id>``. Matching chunks are retrieved and joined
    into the prompt context, and the seq2seq model generates up to 200 new
    tokens.

    Returns:
        ``{"response": ...}`` with the generated answer, or ``{"error": ...}``
        if ``user_id`` would resolve outside the storage root or no store
        exists for it.
    """
    user_id = data.user_id
    query_text = data.query

    if user_id not in user_vectorstores:
        # user_id is untrusted. Opening Chroma on an arbitrary existing
        # directory would create files there, so the resolved path must stay
        # strictly inside the storage root.
        root = os.path.realpath("./chroma_db_users")
        persist_dir = os.path.realpath(os.path.join(root, user_id))
        if persist_dir == root or os.path.commonpath([root, persist_dir]) != root:
            return {"error": f"Invalid user_id {user_id!r}"}
        if not os.path.isdir(persist_dir):
            return {"error": f"No vectorstore found for user {user_id}"}
        user_vectorstores[user_id] = Chroma(
            persist_directory=persist_dir,
            embedding_function=embeddings,
        )

    vectorstore = user_vectorstores[user_id]
    retriever = vectorstore.as_retriever()
    docs = retriever.invoke(query_text)
    context = "\n\n".join(doc.page_content for doc in docs)

    template_text = (
        "Answer the question based ONLY on the context below:\n"
        "Context: {context}\n"
        "Question: {question}"
    )
    prompt_template = ChatPromptTemplate.from_template(template_text)
    prompt = prompt_template.format(context=context, question=query_text)

    # NOTE(review): tokenization and generate() are synchronous calls inside
    # this async handler, so they block the event loop while they run.
    input_ids = llm_tokenizer(prompt, return_tensors="pt").input_ids.to(llm_model.device)
    output_ids = llm_model.generate(input_ids, max_new_tokens=200)
    response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # The decoder of an encoder-decoder model normally doesn't repeat the
    # input, so this strip is only a safeguard against an echoed prompt.
    return {"response": response.replace(prompt, "").strip()}
if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, port 7860, auto-reload on.
    # NOTE(review): the "app:app" import string assumes this file is app.py.
    from uvicorn import run as serve

    serve("app:app", host="0.0.0.0", port=7860, reload=True)
|