# LLM / app.py
# (Repository page header captured along with the file: author "hmm183",
# commit 109fd2d "Update app.py", marked verified.)
import os
# Redirect every Hugging Face cache location to /app/hf_cache (a path that is
# writable inside a Space). This has to run before the HF-related imports
# further down so those libraries see the variables when they are loaded.
os.environ.update(
    {
        "HF_HOME": "/app/hf_cache",
        "TRANSFORMERS_CACHE": "/app/hf_cache",
        "HF_DATASETS_CACHE": "/app/hf_cache",
        "HF_METRICS_CACHE": "/app/hf_cache",
        "SENTENCE_TRANSFORMERS_CACHE": "/app/hf_cache",
    }
)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
# Application instance; the startup hook and both endpoints below are
# registered on it.
app = FastAPI()
# CORS is configured with wildcards for origins, methods and headers.
# NOTE(review): allow_origins=["*"] is combined with allow_credentials=True.
# FastAPI's CORS documentation advises listing origins explicitly when
# credentials are allowed -- confirm this combination is intended before
# exposing the API publicly.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ---------------------------------------------------------------
# Models used by the service
# ---------------------------------------------------------------
# Encoder-decoder (T5) checkpoint that generates the answers.
LLM_MODEL_NAME = "google/flan-t5-small"
# Sentence-embedding checkpoint handed to the vector stores.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Heavy objects start out unset; the startup hook assigns them.
llm_tokenizer = llm_model = embeddings = None

# Vector stores already opened in this process, keyed by user_id.
user_vectorstores = {}
class LoadDocRequest(BaseModel):
    """Request body accepted by the ``/load_document`` endpoint."""

    # Identifier of the user the uploaded document belongs to.
    user_id: str
    # Raw document text that will be split into chunks and indexed.
    text: str
class QueryRequest(BaseModel):
    """Request body accepted by the ``/query`` endpoint."""

    # Identifier of the user whose vector store should be searched.
    user_id: str
    # Natural-language question to answer from that user's documents.
    query: str
@app.on_event("startup")
async def load_models():
    """Load the tokenizer, the seq2seq model and the embedder at startup.

    Populates the module-level ``llm_tokenizer``, ``llm_model`` and
    ``embeddings`` globals. The language model is moved to CUDA when torch
    reports it as available, otherwise it stays on the CPU.
    """
    global llm_tokenizer, llm_model, embeddings

    # Pick the GPU whenever torch can see one.
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")

    # Tokenizer and model come from the same checkpoint name.
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
    llm_model = seq2seq.to(target_device)

    # Embedding model shared by every user's vector store.
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
# Root directory that holds one persisted Chroma store per user.
_USER_DB_ROOT = "./chroma_db_users"


def _user_persist_dir(user_id):
    """Return the Chroma directory for *user_id*, or ``None`` if it is unsafe.

    ``user_id`` comes straight from the request body, so it must not be able
    to steer the path outside ``_USER_DB_ROOT`` (e.g. ``"../x"``) or resolve
    to the root itself (e.g. ``""`` or ``"a/.."``). For every accepted id the
    returned string is exactly ``./chroma_db_users/<user_id>/``, so stores
    persisted earlier are still found.
    """
    candidate = f"{_USER_DB_ROOT}/{user_id}/"
    try:
        root = os.path.realpath(_USER_DB_ROOT)
        resolved = os.path.realpath(candidate)
        inside_root = os.path.commonpath([root, resolved]) == root
    except ValueError:
        # Raised for ids the OS cannot turn into a path (e.g. embedded NUL).
        return None
    if not inside_root or resolved == root:
        return None
    return candidate


@app.post("/load_document")
async def load_document(data: LoadDocRequest):
    """Split the uploaded text into chunks and index them for the user.

    The text is wrapped in a single ``Document`` (source ``"upload"``), split
    with ``chunk_size=1000`` / ``chunk_overlap=200``, and stored in a Chroma
    store persisted under the user's directory. The store is also cached in
    ``user_vectorstores``.

    Returns:
        ``{"message": ...}`` with the number of chunks on success, or
        ``{"error": ...}`` when ``user_id`` is not a safe directory name or
        the text yields no chunks.
    """
    user_id = data.user_id

    persist_dir = _user_persist_dir(user_id)
    if persist_dir is None:
        return {"error": f"Invalid user_id {user_id!r}"}

    base_document = Document(page_content=data.text, metadata={"source": "upload"})
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents([base_document])
    if not chunks:
        # Nothing to index: bail out before creating a directory, otherwise
        # /query would later treat the empty directory as a valid store.
        return {"error": "Document text is empty; nothing to index"}

    os.makedirs(persist_dir, exist_ok=True)
    vectorstore = Chroma.from_documents(
        chunks, embedding=embeddings, persist_directory=persist_dir
    )
    user_vectorstores[user_id] = vectorstore
    return {"message": f"Loaded {len(chunks)} chunks for user {user_id}"}


@app.post("/query")
async def query(data: QueryRequest):
    """Answer a question using only the user's indexed document chunks.

    Uses the cached vector store for the user, or reopens the persisted one
    from disk when the cache has no entry. The retrieved chunks are joined
    into a context, formatted into the prompt, and passed to the seq2seq
    model with ``max_new_tokens=200``.

    Returns:
        ``{"response": ...}`` with the generated text, or ``{"error": ...}``
        when no store exists for the user (including unsafe ``user_id``
        values, which can never have a store).
    """
    user_id = data.user_id
    query_text = data.query

    # Use the cached store, or lazily reopen the one persisted on disk.
    vectorstore = user_vectorstores.get(user_id)
    if vectorstore is None:
        persist_dir = _user_persist_dir(user_id)
        if persist_dir is None or not os.path.exists(persist_dir):
            return {"error": f"No vectorstore found for user {user_id}"}
        vectorstore = Chroma(
            persist_directory=persist_dir,
            embedding_function=embeddings,
        )
        user_vectorstores[user_id] = vectorstore

    retriever = vectorstore.as_retriever()
    docs = retriever.invoke(query_text)
    context = "\n\n".join(doc.page_content for doc in docs)

    # NOTE(review): ChatPromptTemplate.format() is believed to render the
    # message with a "Human: " role prefix -- confirm that is wanted for a
    # plain seq2seq model.
    prompt_template = ChatPromptTemplate.from_template(
        """Answer the question based ONLY on the context below:
Context: {context}
Question: {question}"""
    )
    prompt = prompt_template.format(context=context, question=query_text)

    # Encode prompt and generate.
    # NOTE(review): the prompt is not truncated here; confirm long contexts
    # stay within what the model handles well.
    input_ids = llm_tokenizer(prompt, return_tensors="pt").input_ids.to(llm_model.device)
    output_ids = llm_model.generate(input_ids, max_new_tokens=200)
    response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Remove the prompt prefix if the model echoes it.
    return {"response": response.replace(prompt, "").strip()}
if __name__ == "__main__":
    # Script entry point: serve this module's ``app`` with uvicorn on all
    # interfaces, port 7860, with auto-reload switched on.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860, "reload": True}
    uvicorn.run("app:app", **server_options)