File size: 4,036 Bytes
f4b962e
3cc6209
109fd2d
9b3fc24
 
 
 
 
3cc6209
 
b44200d
 
 
109fd2d
f4b962e
d6795cf
 
 
b44200d
 
 
109fd2d
b44200d
 
 
 
 
 
 
 
3cc6209
 
 
109fd2d
b44200d
 
 
3cc6209
b44200d
3cc6209
b44200d
 
 
 
 
 
 
 
 
 
 
3cc6209
b44200d
3cc6209
109fd2d
b44200d
109fd2d
3cc6209
109fd2d
b44200d
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc6209
 
 
b44200d
 
 
 
 
 
 
 
109fd2d
b44200d
 
 
3cc6209
109fd2d
 
3cc6209
d6795cf
b44200d
d6795cf
b44200d
 
 
d6795cf
b44200d
 
 
d6795cf
b44200d
 
 
d6795cf
109fd2d
b44200d
 
 
d6795cf
109fd2d
b44200d
f4b962e
d6795cf
3cc6209
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os

# Send every Hugging Face cache (hub, transformers, datasets, metrics and
# sentence-transformers) to /app/hf_cache, a writable location in Spaces.
# This runs before the Hugging Face imports below, presumably because those
# libraries read these variables when they are first imported.
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HF_DATASETS_CACHE",
            "HF_METRICS_CACHE",
            "SENTENCE_TRANSFORMERS_CACHE",
        ),
        "/app/hf_cache",
    )
)

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate

app = FastAPI()

# Wide-open CORS policy: every origin, method and header is accepted, and
# credentials are allowed on cross-origin requests.
# NOTE(review): pairing allow_credentials=True with allow_origins=["*"] may let
# any website make credentialed cross-origin calls. Nothing in this file reads
# cookies or auth headers, so allow_credentials=False or an explicit origin
# list is probably safe here; confirm against the front end before changing.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# ────────────────
# Model configuration
# ────────────────
# Encoder-decoder (seq2seq) T5 checkpoint used to generate answers.
LLM_MODEL_NAME = "google/flan-t5-small"
# Embedding checkpoint (loaded through HuggingFaceEmbeddings) used to embed
# document chunks when indexing and query text when retrieving.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Runtime state. The three model handles stay None until load_models() runs
# at startup; user_vectorstores caches one Chroma store per user_id.
llm_tokenizer = llm_model = embeddings = None
user_vectorstores = {}

class LoadDocRequest(BaseModel):
    """Request body for ``POST /load_document``."""

    user_id: str  # owner of the document; also used as the on-disk Chroma directory name
    text: str  # raw text that is split into chunks and indexed

class QueryRequest(BaseModel):
    """Request body for ``POST /query``."""

    user_id: str  # selects which user's vectorstore is searched
    query: str  # question answered from the retrieved context

@app.on_event("startup")
async def load_models():
    """Load the generation and embedding models once, when the app starts.

    Fills the module-level ``llm_tokenizer``, ``llm_model`` and ``embeddings``
    handles that the request handlers rely on. The T5 model is moved to the
    GPU when CUDA is available; otherwise it stays on the CPU.
    """
    global llm_tokenizer, llm_model, embeddings

    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda") if use_cuda else torch.device("cpu")

    # Seq2seq (encoder-decoder) model plus its matching tokenizer.
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
    llm_model = seq2seq.to(target_device)

    # Embedding model used for indexing and retrieval; no device is passed here.
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)

@app.post("/load_document")
async def load_document(data: LoadDocRequest):
    """Split the uploaded text into chunks and index them for ``data.user_id``.

    The text is split with chunk_size=1000 / chunk_overlap=200. The chunks are
    embedded with the global ``embeddings`` model and persisted with Chroma
    under ``./chroma_db_users/<user_id>/``. The resulting store is also cached
    in ``user_vectorstores`` so ``/query`` can reuse it.

    Returns:
        ``{"message": ...}`` with the number of chunks indexed, or
        ``{"error": ...}`` when ``user_id`` is not a safe directory name or
        the text produced no chunks.
    """
    user_id = data.user_id
    text = data.text

    # user_id comes straight from the request body and becomes a directory
    # name. Refuse anything that would resolve outside the per-user storage
    # root (e.g. "../x", "/etc", "" or a NUL byte).
    if "\0" in user_id:
        return {"error": f"Invalid user_id {user_id!r}"}
    users_root = os.path.realpath("./chroma_db_users")
    user_root = os.path.realpath(os.path.join(users_root, user_id))
    if user_root == users_root or os.path.commonpath([users_root, user_root]) != users_root:
        return {"error": f"Invalid user_id {user_id!r}"}

    base_document = Document(page_content=text, metadata={"source": "upload"})
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents([base_document])
    if not chunks:
        # Text that produces no chunks (e.g. an empty string) has nothing to
        # index. Report it instead of handing Chroma an empty batch and
        # leaving an empty directory behind.
        return {"error": f"No text to load for user {user_id}"}

    persist_dir = f"./chroma_db_users/{user_id}/"
    os.makedirs(persist_dir, exist_ok=True)

    # NOTE(review): if persist_dir already holds a store, from_documents
    # presumably adds these chunks to the existing collection rather than
    # replacing it. Confirm that repeated uploads are meant to accumulate.
    vectorstore = Chroma.from_documents(
        chunks, embedding=embeddings, persist_directory=persist_dir
    )
    user_vectorstores[user_id] = vectorstore
    return {"message": f"Loaded {len(chunks)} chunks for user {user_id}"}

@app.post("/query")
async def query(data: QueryRequest):
    """Answer a question using the user's indexed documents as context.

    On first use the user's Chroma store is opened from
    ``./chroma_db_users/<user_id>/`` and cached in ``user_vectorstores``.
    The handler then retrieves matching chunks, builds a context-only prompt,
    and generates an answer (up to 200 new tokens) with the global T5 model.

    Returns:
        ``{"response": <answer>}``, or ``{"error": ...}`` when ``user_id`` is
        not a safe directory name or no vectorstore exists for it.
    """
    user_id = data.user_id
    query_text = data.query

    # user_id is interpolated into a filesystem path below, and opening a
    # Chroma store on a directory may write to it. Refuse anything that
    # resolves outside the per-user storage root (e.g. "../x", "/etc", "").
    # The NUL-byte check keeps such ids on the error-dict path instead of
    # raising.
    if "\0" in user_id:
        return {"error": f"Invalid user_id {user_id!r}"}
    users_root = os.path.realpath("./chroma_db_users")
    user_root = os.path.realpath(os.path.join(users_root, user_id))
    if user_root == users_root or os.path.commonpath([users_root, user_root]) != users_root:
        return {"error": f"Invalid user_id {user_id!r}"}

    # Ensure vectorstore exists for this user (lazily reopened from disk).
    if user_id not in user_vectorstores:
        persist_dir = f"./chroma_db_users/{user_id}/"
        if not os.path.exists(persist_dir):
            return {"error": f"No vectorstore found for user {user_id}"}
        user_vectorstores[user_id] = Chroma(
            persist_directory=persist_dir,
            embedding_function=embeddings,
        )

    vectorstore = user_vectorstores[user_id]
    retriever = vectorstore.as_retriever()
    docs = retriever.invoke(query_text)

    context = "\n\n".join(doc.page_content for doc in docs)
    # NOTE(review): ChatPromptTemplate.format() renders the message as a chat
    # transcript string (expected to carry a role prefix such as "Human: ").
    # A plain string template may suit this non-chat T5 model better; confirm
    # before changing.
    prompt_template = ChatPromptTemplate.from_template(
        """Answer the question based ONLY on the context below:
Context: {context}
Question: {question}"""
    )
    prompt = prompt_template.format(context=context, question=query_text)

    # Encode the prompt on the model's device and generate the answer.
    input_ids = llm_tokenizer(prompt, return_tensors="pt").input_ids.to(llm_model.device)
    output_ids = llm_model.generate(input_ids, max_new_tokens=200)
    # T5 is encoder-decoder: the prompt goes to the encoder and the decoded
    # output holds only generated tokens, so there is no prompt echo to strip.
    response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"response": response.strip()}

if __name__ == "__main__":
    # Local entry point: serve on all interfaces, port 7860, with auto-reload.
    # The "app:app" import string presumes this module is saved as app.py.
    from uvicorn import run as _serve

    _serve("app:app", host="0.0.0.0", port=7860, reload=True)