Update app.py
Browse files
app.py
CHANGED
|
@@ -1,267 +1,104 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# Ensure the directory exists
|
| 10 |
-
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
|
| 11 |
-
|
| 12 |
-
# --- Flask and CORS ---
|
| 13 |
-
from flask import Flask, request, jsonify
|
| 14 |
-
from flask_cors import CORS
|
| 15 |
-
|
| 16 |
-
# --- LangChain and Hugging Face Libraries ---
|
| 17 |
-
# Note: We are NOT using Ollama directly in this app.py for Hugging Face Spaces.
|
| 18 |
-
# Instead, we are loading models directly via Hugging Face's transformers library.
|
| 19 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 20 |
-
import torch # For checking GPU availability and model dtype
|
| 21 |
-
|
| 22 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 23 |
from langchain_community.vectorstores import Chroma
|
| 24 |
from langchain_core.documents import Document
|
| 25 |
-
from langchain_core.prompts import ChatPromptTemplate
|
| 26 |
-
from langchain_core.output_parsers import StrOutputParser
|
| 27 |
-
from langchain_core.runnables import RunnablePassthrough
|
| 28 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
#
|
| 46 |
-
#
|
| 47 |
-
# -
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
# IMPORTANT: local_files_only=True means it will NOT try to download if not found.
|
| 98 |
-
# If you want it to download if not present, remove this line or set to False.
|
| 99 |
-
# For robust deployment, pre-caching and uploading the model is recommended.
|
| 100 |
-
model_kwargs={"local_files_only": False} # Set to False to allow download if not cached
|
| 101 |
-
)
|
| 102 |
-
print("Embedding Model initialized successfully!")
|
| 103 |
-
|
| 104 |
-
except Exception as e:
|
| 105 |
-
print(f"ERROR: An unexpected error occurred during model initialization: {e}")
|
| 106 |
-
llm_pipeline = None
|
| 107 |
-
embeddings = None
|
| 108 |
-
# Re-raise the exception to prevent the Flask app from starting if models fail to load
|
| 109 |
-
raise e
|
| 110 |
-
|
| 111 |
-
@app.route('/load_document', methods=['POST'])
|
| 112 |
-
def load_document():
|
| 113 |
-
"""
|
| 114 |
-
Load a document for a specific user into their dedicated persistent vector store.
|
| 115 |
-
The text is chunked for better retrieval.
|
| 116 |
-
"""
|
| 117 |
-
if not embeddings:
|
| 118 |
-
return jsonify({"error": "Embedding model not initialized. Server might be restarting or failed to load models."}), 500
|
| 119 |
-
|
| 120 |
-
data = request.get_json()
|
| 121 |
-
user_id = data.get("user_id") # Expecting a user_id from the client
|
| 122 |
-
text = data.get("text")
|
| 123 |
-
|
| 124 |
-
if not user_id:
|
| 125 |
-
return jsonify({"error": "User ID (user_id) is required to load a document."}), 400
|
| 126 |
-
if not text:
|
| 127 |
-
return jsonify({"error": "No text provided to load."}), 400
|
| 128 |
-
|
| 129 |
-
print(f"Loading document for user: {user_id}")
|
| 130 |
-
|
| 131 |
-
try:
|
| 132 |
-
# Create a unique persistence directory for each user's ChromaDB
|
| 133 |
-
# This will be within the Space's storage, which can be ephemeral on restarts.
|
| 134 |
-
persist_dir = f"{os.environ['HF_HOME']}/chroma_db_users/{user_id}/"
|
| 135 |
-
os.makedirs(persist_dir, exist_ok=True)
|
| 136 |
-
|
| 137 |
-
# Wrap the input text in a LangChain Document
|
| 138 |
-
base_document = Document(page_content=text, metadata={"user_id": user_id, "source": "user_upload"})
|
| 139 |
-
|
| 140 |
-
# Chunk the document for better retrieval performance
|
| 141 |
-
text_splitter = RecursiveCharacterTextSplitter(
|
| 142 |
-
chunk_size=1000, # Max characters per chunk
|
| 143 |
-
chunk_overlap=200, # Overlap between chunks to maintain context
|
| 144 |
-
length_function=len,
|
| 145 |
-
is_separator_regex=False,
|
| 146 |
-
)
|
| 147 |
-
chunks = text_splitter.split_documents([base_document])
|
| 148 |
-
|
| 149 |
-
# Create/overwrite the vector store for this specific user
|
| 150 |
-
# This will save to the user-specific directory on disk.
|
| 151 |
-
user_vectorstores[user_id] = Chroma.from_documents(
|
| 152 |
-
chunks, embedding=embeddings, persist_directory=persist_dir
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
print(f"Document loaded for user '{user_id}'. Chunks created: {len(chunks)} at {persist_dir}")
|
| 156 |
-
return jsonify({"message": f"Document loaded successfully for user '{user_id}'.", "chunks_created": len(chunks)})
|
| 157 |
-
except Exception as e:
|
| 158 |
-
print(f"Error loading document for user '{user_id}': {e}")
|
| 159 |
-
import traceback
|
| 160 |
-
traceback.print_exc() # Print full traceback for debugging
|
| 161 |
-
return jsonify({"error": f"Error loading document: {e}"}), 500
|
| 162 |
-
|
| 163 |
-
@app.route('/query', methods=['POST'])
|
| 164 |
-
def query():
|
| 165 |
-
"""
|
| 166 |
-
Query the currently loaded document for a specific user to summarize or answer a question.
|
| 167 |
-
"""
|
| 168 |
-
if not llm_pipeline or not embeddings:
|
| 169 |
-
return jsonify({"error": "Models not initialized. Server might be restarting or failed to load models."}), 500
|
| 170 |
-
|
| 171 |
-
data = request.get_json()
|
| 172 |
-
user_id = data.get("user_id")
|
| 173 |
-
query_text = data.get("query")
|
| 174 |
-
|
| 175 |
-
if not user_id:
|
| 176 |
-
return jsonify({"error": "User ID (user_id) is required to query."}), 400
|
| 177 |
-
if not query_text:
|
| 178 |
-
return jsonify({"error": "No query text provided."}), 400
|
| 179 |
-
|
| 180 |
-
print(f"Query received for user: {user_id}, Query: '{query_text}'")
|
| 181 |
-
|
| 182 |
-
# Retrieve the vector store for this specific user from the cache
|
| 183 |
-
current_user_vectorstore = user_vectorstores.get(user_id)
|
| 184 |
-
|
| 185 |
-
# If not in memory, attempt to load from disk for this user
|
| 186 |
-
if not current_user_vectorstore:
|
| 187 |
-
user_persist_dir = f"{os.environ['HF_HOME']}/chroma_db_users/{user_id}/"
|
| 188 |
-
if os.path.exists(user_persist_dir):
|
| 189 |
-
try:
|
| 190 |
-
# Load the existing vectorstore from disk
|
| 191 |
-
current_user_vectorstore = Chroma(persist_directory=user_persist_dir, embedding_function=embeddings)
|
| 192 |
-
user_vectorstores[user_id] = current_user_vectorstore # Cache it in memory for subsequent queries
|
| 193 |
-
print(f"Loaded existing vectorstore for user '{user_id}' from disk.")
|
| 194 |
-
except Exception as e:
|
| 195 |
-
print(f"Error loading vectorstore from disk for user '{user_id}': {e}")
|
| 196 |
-
return jsonify({"error": f"Failed to load document for user '{user_id}'. Please try loading it again or check server logs."}), 500
|
| 197 |
else:
|
| 198 |
-
return
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
# Create a prompt template geared toward Q&A based on context
|
| 204 |
-
prompt_template = ChatPromptTemplate.from_template(
|
| 205 |
-
"""Answer the question based ONLY on the following context. If the answer is not available in the provided context, politely state that you cannot find the answer in the provided information.
|
| 206 |
|
|
|
|
|
|
|
|
|
|
| 207 |
Context: {context}
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
# --- RAG Chain for Hugging Face Pipeline ---
|
| 214 |
-
# Get relevant context documents
|
| 215 |
-
retrieved_docs = retriever.invoke(query_text)
|
| 216 |
-
context_text = "\n\n".join([doc.page_content for doc in retrieved_docs])
|
| 217 |
-
|
| 218 |
-
# Format the prompt using the template and retrieved context
|
| 219 |
-
formatted_prompt = prompt_template.format(context=context_text, question=query_text)
|
| 220 |
-
|
| 221 |
-
# Use the Hugging Face pipeline directly for text generation
|
| 222 |
-
outputs = llm_pipeline(formatted_prompt)
|
| 223 |
-
|
| 224 |
-
# The output from the pipeline needs to be parsed based on its structure
|
| 225 |
-
# It's usually a list of dictionaries, with 'generated_text' key.
|
| 226 |
-
generated_text = outputs[0]['generated_text']
|
| 227 |
|
| 228 |
-
|
| 229 |
-
# This is a common challenge with text generation.
|
| 230 |
-
# A simple way is to find the query in the generated text and take what comes after.
|
| 231 |
-
response_start_index = generated_text.find(formatted_prompt)
|
| 232 |
-
if response_start_index != -1:
|
| 233 |
-
response = generated_text[response_start_index + len(formatted_prompt):].strip()
|
| 234 |
-
else:
|
| 235 |
-
response = generated_text.strip() # Fallback if prompt isn't found perfectly
|
| 236 |
-
|
| 237 |
-
# Further clean-up to remove any trailing prompt parts the model might generate
|
| 238 |
-
if response.startswith("Summary:"):
|
| 239 |
-
response = response[len("Summary:"):].strip()
|
| 240 |
-
if response.startswith("Answer:"):
|
| 241 |
-
response = response[len("Answer:"):].strip()
|
| 242 |
-
if response.startswith("Question:"):
|
| 243 |
-
response = response[len("Question:"):].strip()
|
| 244 |
-
if response.startswith("Context:"):
|
| 245 |
-
response = response[len("Context:"):].strip()
|
| 246 |
|
| 247 |
-
|
| 248 |
-
print(f"Response generated for user '{user_id}'.")
|
| 249 |
-
return jsonify({"response": response})
|
| 250 |
-
except Exception as e:
|
| 251 |
-
print(f"ERROR: An unexpected error occurred during query for user '{user_id}': {e}")
|
| 252 |
-
import traceback
|
| 253 |
-
traceback.print_exc()
|
| 254 |
-
return jsonify({"error": f"Error processing query: {e}"}), 500
|
| 255 |
-
|
| 256 |
if __name__ == "__main__":
|
| 257 |
-
|
| 258 |
-
initialize_models()
|
| 259 |
-
print(f"Starting Flask RAG MVP application on http://0.0.0.0:7860 (Hugging Face Spaces default port)")
|
| 260 |
-
print(f"Using LLM: {LLM_MODEL_NAME_HF}, Embeddings: {EMBEDDING_MODEL_NAME_HF}")
|
| 261 |
-
print("API endpoints:")
|
| 262 |
-
print(" - POST /load_document (Requires 'user_id' and 'text')")
|
| 263 |
-
print(" - POST /query (Requires 'user_id' and 'query')")
|
| 264 |
-
|
| 265 |
-
# Hugging Face Spaces typically runs on port 7860
|
| 266 |
-
app.run(host="0.0.0.0", port=7860)
|
| 267 |
-
|
|
|
|
| 1 |
+
# fastapi_app.py
|
| 2 |
import os
|
| 3 |
+
from fastapi import FastAPI, Request
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
import uvicorn
|
| 7 |
+
from typing import Dict
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 11 |
from langchain_community.vectorstores import Chroma
|
| 12 |
from langchain_core.documents import Document
|
|
|
|
|
|
|
|
|
|
| 13 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 14 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 15 |
+
import asyncio
|
| 16 |
+
|
| 17 |
+
# Point the Hugging Face cache at a directory next to the application.
# NOTE(review): `transformers` is imported above this assignment, so the
# library may already have resolved its cache location — confirm that this
# variable is actually honoured at this point.
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"

# -----------------------------
# Model identifiers
# -----------------------------
LLM_MODEL_NAME = "google/flan-t5-small"  # Lightweight and fast on CPU
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# -----------------------------
# Process-wide state
# -----------------------------
# The three model handles stay None until the startup hook assigns them.
llm_model = llm_tokenizer = embeddings = None
# In-memory cache of vector stores, keyed by user_id.
user_vectorstores: Dict[str, Chroma] = dict()

# -----------------------------
# Application and CORS setup
# -----------------------------
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
| 40 |
+
|
| 41 |
+
class LoadDocRequest(BaseModel):
    """JSON body accepted by ``POST /load_document``."""

    # Identifies the user whose vector store receives the document.
    user_id: str
    # Raw document text; the handler splits it into chunks and embeds it.
    text: str
|
| 44 |
+
|
| 45 |
+
class QueryRequest(BaseModel):
    """JSON body accepted by ``POST /query``."""

    # Identifies the user whose vector store is searched.
    user_id: str
    # Question to answer from that user's stored document chunks.
    query: str
|
| 48 |
+
|
| 49 |
+
@app.on_event("startup")
async def load_models():
    """Load the tokenizer, the LLM and the embedding model at startup.

    Assigns the module-level ``llm_tokenizer``, ``llm_model`` and
    ``embeddings`` globals that the request handlers read. Any loading
    error propagates, so the application does not start without models.
    """
    global llm_model, llm_tokenizer, embeddings

    # Imported here so the fix is self-contained and does not depend on the
    # file-level import list.
    from transformers import AutoConfig, AutoModelForSeq2SeqLM

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)

    # FLAN-T5 is an encoder-decoder checkpoint; AutoModelForCausalLM does not
    # accept a T5 config and raises at load time. Choose the auto class from
    # the checkpoint's own config so decoder-only models keep working too.
    config = AutoConfig.from_pretrained(LLM_MODEL_NAME)
    if config.is_encoder_decoder:
        model_cls = AutoModelForSeq2SeqLM
    else:
        model_cls = AutoModelForCausalLM
    llm_model = model_cls.from_pretrained(LLM_MODEL_NAME).to(device)

    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
|
| 56 |
+
|
| 57 |
+
@app.post("/load_document")
async def load_document(data: LoadDocRequest):
    """Chunk a user's document and store it in that user's Chroma store.

    Args:
        data: Request body carrying ``user_id`` and the document ``text``.

    Returns:
        ``{"message": ...}`` on success. On invalid input or when the
        embedding model is unavailable, a JSON response with an ``"error"``
        key and a 400/503 status code.
    """
    # Local import: the file-level imports do not provide JSONResponse.
    from fastapi.responses import JSONResponse

    user_id = data.user_id
    text = data.text

    # user_id becomes a directory name below, so refuse anything that could
    # escape ./chroma_db_users/ (separators, "." / "..", NUL) or that would
    # collapse onto the shared root directory (empty string).
    if (
        not user_id
        or user_id in {".", ".."}
        or any(ch in user_id for ch in "/\\\x00")
    ):
        return JSONResponse(status_code=400, content={"error": "Invalid user_id."})

    # Empty or whitespace-only text yields no chunks, and the vector store
    # cannot be built from an empty list; report it as a client error.
    if not text.strip():
        return JSONResponse(
            status_code=400, content={"error": "No text provided to load."}
        )

    if embeddings is None:
        return JSONResponse(
            status_code=503, content={"error": "Embedding model not initialized."}
        )

    persist_dir = f"./chroma_db_users/{user_id}/"
    os.makedirs(persist_dir, exist_ok=True)

    base_document = Document(page_content=text, metadata={"source": "upload"})
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents([base_document])

    vectorstore = Chroma.from_documents(
        chunks, embedding=embeddings, persist_directory=persist_dir
    )
    # Keep the store in memory so later queries skip reopening it from disk.
    user_vectorstores[user_id] = vectorstore
    return {"message": f"Loaded {len(chunks)} chunks for user {user_id}"}
|
| 72 |
+
|
| 73 |
+
@app.post("/query")
async def query(data: QueryRequest):
    """Answer a question from the chunks stored for one user.

    Retrieves the chunks most relevant to the query from the user's Chroma
    store (reopened from disk if it is not cached in memory), builds a
    context-restricted prompt and generates an answer with the loaded LLM.

    Args:
        data: Request body carrying ``user_id`` and the ``query`` text.

    Returns:
        ``{"response": ...}`` on success. Otherwise a JSON response with an
        ``"error"`` key and status 400 (bad input), 404 (no store for the
        user) or 503 (models not loaded).
    """
    # Local import: the file-level imports do not provide JSONResponse.
    from fastapi.responses import JSONResponse

    user_id = data.user_id
    query_text = data.query

    # user_id is interpolated into a filesystem path below; reject values
    # that could point outside ./chroma_db_users/.
    if (
        not user_id
        or user_id in {".", ".."}
        or any(ch in user_id for ch in "/\\\x00")
    ):
        return JSONResponse(status_code=400, content={"error": "Invalid user_id."})

    if not query_text.strip():
        return JSONResponse(
            status_code=400, content={"error": "No query text provided."}
        )

    if llm_model is None or llm_tokenizer is None or embeddings is None:
        return JSONResponse(
            status_code=503, content={"error": "Models not initialized."}
        )

    vectorstore = user_vectorstores.get(user_id)
    if vectorstore is None:
        # Not cached in this process: try the persisted copy on disk.
        persist_dir = f"./chroma_db_users/{user_id}/"
        if not os.path.isdir(persist_dir):
            # Same body as before, but with a 404 so clients can tell the
            # failure apart from a successful answer.
            return JSONResponse(
                status_code=404,
                content={"error": f"No vectorstore found for user {user_id}"},
            )
        vectorstore = Chroma(
            persist_directory=persist_dir, embedding_function=embeddings
        )
        user_vectorstores[user_id] = vectorstore

    retriever = vectorstore.as_retriever()
    docs = retriever.invoke(query_text)

    context = "\n\n".join(doc.page_content for doc in docs)
    prompt_template = ChatPromptTemplate.from_template(
        """Answer the question based ONLY on the context below:
Context: {context}
Question: {question}"""
    )
    prompt = prompt_template.format(context=context, question=query_text)

    input_ids = llm_tokenizer(prompt, return_tensors="pt").input_ids.to(
        llm_model.device
    )
    output_ids = llm_model.generate(input_ids, max_new_tokens=200)
    response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Drop the prompt if the decoded output echoes it back.
    return {"response": response.replace(prompt, "").strip()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
if __name__ == "__main__":
    # Hand uvicorn the app object itself rather than the "fastapi_app:app"
    # import string: the string only resolves when this file is named
    # fastapi_app.py. Auto-reload needs an import string and is a
    # development-only feature, so it is not enabled here.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|