| import os |
|
|
| |
# Point every Hugging Face cache (hub models, datasets, metrics,
# sentence-transformers) at one writable directory. This must run before the
# transformers / langchain imports below, which may read these variables at
# import time.
# NOTE(review): some of these names (e.g. TRANSFORMERS_CACHE) are deprecated in
# newer library versions in favour of HF_HOME — confirm which are still needed.
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HF_DATASETS_CACHE",
            "HF_METRICS_CACHE",
            "SENTENCE_TRANSFORMERS_CACHE",
        ),
        "/app/hf_cache",
    )
)
|
|
| from fastapi import FastAPI |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| import torch |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from langchain_community.vectorstores import Chroma |
| from langchain_core.documents import Document |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from langchain_core.prompts import ChatPromptTemplate |
|
|
# ASGI application; the route handlers below register themselves on it.
app = FastAPI()

# Wide-open CORS: browsers may call this API from any origin, with any method
# and any request header.
# NOTE(review): pairing a wildcard origin with allow_credentials=True makes
# Starlette reflect the caller's Origin header back, so any web page can send
# credentialed cross-origin requests. No cookie/auth handling is visible in
# this module; consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
| |
| |
| |
# Hugging Face Hub model ids loaded by the startup hook.
LLM_MODEL_NAME = "google/flan-t5-small"  # seq2seq answer generator
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # chunk/query embedder

# Model handles assigned by load_models() at startup; None until then.
llm_tokenizer = llm_model = embeddings = None

# In-memory cache: user_id -> Chroma vector store for that user.
# NOTE(review): nothing visible here ever removes entries, so this grows with
# the number of distinct user_ids seen by the process.
user_vectorstores = {}
|
|
# JSON request body for POST /load_document.
class LoadDocRequest(BaseModel):
    # Caller-supplied id; the handler uses it as a directory name under
    # ./chroma_db_users/, so treat it as untrusted input.
    user_id: str
    # Raw document text to split into chunks, embed and index.
    text: str
|
|
# JSON request body for POST /query.
class QueryRequest(BaseModel):
    # Caller-supplied id selecting whose vector store to search; it is also
    # used as a directory name under ./chroma_db_users/ (untrusted input).
    user_id: str
    # Natural-language question to answer from the stored documents.
    query: str
|
|
@app.on_event("startup")
async def load_models():
    """Populate the module-level model handles once, when the app starts.

    Assigns ``llm_tokenizer`` and ``llm_model`` (the seq2seq generator, moved
    to CUDA when ``torch.cuda.is_available()`` is true, otherwise CPU) and
    ``embeddings`` (the sentence embedder used by the Chroma stores).
    """
    global llm_tokenizer, llm_model, embeddings

    # NOTE(review): newer FastAPI releases deprecate on_event("startup") in
    # favour of a lifespan handler passed to FastAPI(); migrating would also
    # touch the app construction.
    use_gpu = torch.cuda.is_available()
    target_device = torch.device("cuda") if use_gpu else torch.device("cpu")

    # Generator: tokenizer plus weights, then moved to the chosen device.
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    generator = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
    llm_model = generator.to(target_device)

    # Embedder: no device is passed, so placement is left to the
    # HuggingFaceEmbeddings / sentence-transformers defaults.
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
|
|
@app.post("/load_document")
async def load_document(data: LoadDocRequest):
    """Split an uploaded text into chunks and index them for one user.

    Chunks are embedded into a Chroma store persisted under
    ``./chroma_db_users/<user_id>/`` and the store is cached in
    ``user_vectorstores`` for later ``/query`` calls.

    Returns:
        ``{"message": ...}`` with the number of chunks indexed, or
        ``{"error": ...}`` when ``user_id`` is not a single safe path component.
    """
    user_id = data.user_id
    text = data.text

    # user_id comes straight from the request body and becomes a directory
    # name. Reject anything that is not one plain path component (e.g. "",
    # "..", "a/b", "/etc") so a caller cannot create or write Chroma files
    # outside ./chroma_db_users/.
    if (
        user_id in ("", ".", "..")
        or "\x00" in user_id
        or os.path.basename(user_id) != user_id
    ):
        return {"error": f"Invalid user_id {user_id!r}"}

    base_document = Document(page_content=text, metadata={"source": "upload"})
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents([base_document])

    # Nothing to index (e.g. empty text): don't create an empty on-disk store,
    # and don't hand Chroma an empty batch, which some versions may reject.
    if not chunks:
        return {"message": f"Loaded 0 chunks for user {user_id}"}

    persist_dir = f"./chroma_db_users/{user_id}/"
    os.makedirs(persist_dir, exist_ok=True)

    # NOTE(review): from_documents on an existing persist_directory presumably
    # appends to the stored collection rather than replacing it, so repeated
    # uploads accumulate — confirm this is intended.
    # NOTE(review): splitting, embedding and Chroma writes are blocking calls
    # inside an `async def`, so they stall the event loop while they run.
    vectorstore = Chroma.from_documents(
        chunks, embedding=embeddings, persist_directory=persist_dir
    )
    user_vectorstores[user_id] = vectorstore
    return {"message": f"Loaded {len(chunks)} chunks for user {user_id}"}
|
|
@app.post("/query")
async def query(data: QueryRequest):
    """Answer ``data.query`` from the user's indexed documents.

    Retrieves chunks from the user's Chroma store (reopening it from
    ``./chroma_db_users/<user_id>/`` when it is not cached), packs them into a
    prompt, and generates an answer with the seq2seq model.

    Returns:
        ``{"response": ...}`` with the generated answer, or ``{"error": ...}``
        when ``user_id`` is invalid or the user has no stored documents.
    """
    user_id = data.user_id
    query_text = data.query

    # user_id is untrusted and becomes a directory name; opening a Chroma store
    # can create files, so refuse anything that is not one plain path component
    # and could therefore resolve outside ./chroma_db_users/.
    if (
        user_id in ("", ".", "..")
        or "\x00" in user_id
        or os.path.basename(user_id) != user_id
    ):
        return {"error": f"Invalid user_id {user_id!r}"}

    # Reopen a store persisted earlier (e.g. before a restart) on first use.
    if user_id not in user_vectorstores:
        persist_dir = f"./chroma_db_users/{user_id}/"
        if not os.path.exists(persist_dir):
            return {"error": f"No vectorstore found for user {user_id}"}
        user_vectorstores[user_id] = Chroma(
            persist_directory=persist_dir,
            embedding_function=embeddings,
        )

    retriever = user_vectorstores[user_id].as_retriever()
    docs = retriever.invoke(query_text)
    context = "\n\n".join(doc.page_content for doc in docs)

    # Plain-text prompt for the text-to-text model. Formatting a
    # ChatPromptTemplate would render a "Human: " role prefix, which a
    # non-chat model such as flan-t5 does not expect.
    prompt = (
        "Answer the question based ONLY on the context below:\n"
        f"Context: {context}\n"
        f"Question: {query_text}"
    )

    # NOTE(review): several retrieved chunks (up to 1000 chars each at upload)
    # can exceed flan-t5's 512-token training length. Truncating would cut off
    # the trailing question, so the prompt is passed whole; consider fewer or
    # smaller chunks instead.
    # NOTE(review): retrieval and generation block the event loop inside this
    # `async def`.
    # Passing the whole encoding supplies the attention mask alongside the ids.
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    output_ids = llm_model.generate(**inputs, max_new_tokens=200)

    # A seq2seq decoder emits only the answer, never the prompt, so there is
    # nothing to strip beyond surrounding whitespace.
    response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"response": response.strip()}
|
|
if __name__ == "__main__":
    import uvicorn

    # uvicorn's auto-reloader needs the app as an import string ("module:attr"),
    # not the app object; this assumes the module is importable as `app`.
    # NOTE(review): reload=True is a development setting.
    uvicorn.run(
        "app:app",
        reload=True,
        port=7860,
        host="0.0.0.0",
    )
|
|