Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import Flask, request, jsonify
|
| 2 |
+
from flask_cors import CORS
|
| 3 |
+
import os
|
| 4 |
+
# No requests import needed for Ollama connection check if not using Ollama
|
| 5 |
+
|
| 6 |
+
# Import Hugging Face Transformers
|
| 7 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 8 |
+
import torch # For checking GPU availability
|
| 9 |
+
|
| 10 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings # Using HF Embeddings now
|
| 11 |
+
from langchain_community.vectorstores import Chroma
|
| 12 |
+
from langchain_core.documents import Document
|
| 13 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 14 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 15 |
+
from langchain_core.runnables import RunnablePassthrough
|
| 16 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 17 |
+
|
| 18 |
+
# Flask application object; CORS is enabled for the whole app.
app = Flask(__name__)
CORS(app)

# --- Hugging Face model identifiers ---
# Pick a small LLM: per the original author's note, Gemma 4B is usually too
# large for the free tier, and 'google/gemma-2b-it' is a good conversational
# starting point.
LLM_MODEL_NAME_HF = "google/gemma-2b-it"
# Standard small sentence-embedding model.
EMBEDDING_MODEL_NAME_HF = "sentence-transformers/all-MiniLM-L6-v2"

# --- Model handles, assigned by initialize_models() ---
# Both stay None until initialization succeeds; the request handlers check
# them before doing any work.
llm_pipeline = None  # Hugging Face text-generation pipeline
embeddings = None  # HuggingFaceEmbeddings instance

# --- Per-user vector store cache ---
# Maps a user_id to the Chroma vector store built for that user.
user_vectorstores = {}
|
| 33 |
+
|
| 34 |
+
def initialize_models():
    """Load the LLM text-generation pipeline and the embedding model.

    Assigns the module-level ``llm_pipeline`` and ``embeddings`` globals.
    If anything fails, both globals are reset to ``None`` and the exception
    is re-raised so the application does not start without its models.
    """
    global llm_pipeline, embeddings
    print("Initializing Hugging Face models...")
    try:
        # GPU index 0 when CUDA is available, otherwise -1 (CPU).
        device = 0 if torch.cuda.is_available() else -1
        device_label = 'cuda' if device == 0 else 'cpu'
        print(f"Using device: {device_label}")

        # --- LLM pipeline ---
        # Loading fetches the model weights (the original author notes
        # gemma-2b-it is ~5GB), so it is done once at startup.
        print(f"Loading LLM: {LLM_MODEL_NAME_HF}...")
        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME_HF)
        # bfloat16 on GPU, float32 on CPU.
        weights_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_NAME_HF,
            torch_dtype=weights_dtype,
        )
        pipeline_options = {
            "model": model,
            "tokenizer": tokenizer,
            "max_new_tokens": 500,  # Limit response length
            "device": device,
            # Further generation parameters can be added here, e.g.
            # do_sample=True, top_p=0.9, temperature=0.7
        }
        llm_pipeline = pipeline("text-generation", **pipeline_options)
        print("LLM Pipeline initialized successfully!")

        # --- Embedding model ---
        print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME_HF}...")
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME_HF)
        print("Embedding Model initialized successfully!")

    except Exception as e:
        print(f"ERROR: An unexpected error occurred during model initialization: {e}")
        # Never leave a half-initialized pair of models behind.
        llm_pipeline = None
        embeddings = None
        # Propagate so the app does not start when the models fail to load.
        raise e
|
| 72 |
+
|
| 73 |
+
# --- Helper function to adapt HF pipeline to LangChain's LLM interface ---
|
| 74 |
+
# LangChain's pipeline.py can convert HF pipelines but requires some setup.
|
| 75 |
+
# For simplicity, no wrapper is defined here: the /query handler calls the Hugging Face pipeline itself.
|
| 76 |
+
# We will use it directly in the RAG chain's invoke step.
|
| 77 |
+
|
| 78 |
+
@app.route('/load_document', methods=['POST'])
def load_document():
    """Chunk, embed and persist a user's document for later querying.

    Expects a JSON object body with:
        user_id: identifier of the uploading user; it is also used as the
            name of the per-user Chroma directory under
            ``./chroma_db_users/``.
        text: the raw document text to index.

    Returns:
        JSON with a success message and ``chunks_created`` on success, or
        ``{"error": ...}`` with status 400 for invalid input and 500 when the
        embedding model is unavailable or indexing fails.
    """
    if not embeddings:
        return jsonify({"error": "Embedding model not initialized. Server might be restarting or failed to load models."}), 500

    # silent=True makes a missing, malformed or non-JSON body yield None
    # instead of aborting the request. Anything that is not a JSON object is
    # treated as an empty payload so the field checks below answer with 400
    # rather than crashing on ``None.get``.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    user_id = data.get("user_id")
    text = data.get("text")

    if not user_id:
        return jsonify({"error": "User ID (user_id) is required to load a document."}), 400
    if not text:
        return jsonify({"error": "No text provided to load."}), 400
    if not isinstance(text, str):
        return jsonify({"error": "'text' must be a string."}), 400

    # user_id comes from the client and is interpolated into a filesystem
    # path below, so refuse anything that is not a single path component
    # (e.g. "../../etc" or "a/b") to prevent directory traversal.
    user_dir_name = str(user_id)
    if (
        not isinstance(user_id, (str, int))
        or user_dir_name in {".", ".."}
        or os.path.basename(user_dir_name) != user_dir_name
    ):
        return jsonify({"error": "Invalid user_id: it must be a plain string or integer without path separators."}), 400

    print(f"Loading document for user: {user_id}")

    try:
        # One persistence directory per user.
        # NOTE: On Hugging Face Spaces this directory lives in the Space's
        # storage, which can be ephemeral or reset depending on space
        # type/resource usage. A truly persistent setup needs external
        # storage.
        persist_dir = f"./chroma_db_users/{user_id}/"
        os.makedirs(persist_dir, exist_ok=True)

        base_document = Document(page_content=text, metadata={"user_id": user_id, "source": "user_upload"})
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = text_splitter.split_documents([base_document])

        # NOTE(review): presumably this adds to any collection already
        # persisted in persist_dir, so re-uploading for the same user may
        # accumulate chunks -- confirm against the Chroma wrapper in use.
        user_vectorstores[user_id] = Chroma.from_documents(
            chunks, embedding=embeddings, persist_directory=persist_dir
        )

        print(f"Document loaded for user '{user_id}'. Chunks created: {len(chunks)} at {persist_dir}")
        return jsonify({"message": f"Document loaded successfully for user '{user_id}'.", "chunks_created": len(chunks)})
    except Exception as e:
        # Request boundary: report the failure to the client instead of
        # letting it propagate as an unhandled server error.
        print(f"Error loading document for user '{user_id}': {e}")
        return jsonify({"error": f"Error loading document: {e}"}), 500
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@app.route('/query', methods=['POST'])
def query():
    """Answer a question about the document previously loaded for a user.

    Expects a JSON object body with:
        user_id: identifier of the user whose document should be queried.
        query: the question (or summarization request) to answer.

    The user's vector store is taken from the in-memory cache, or re-opened
    from ``./chroma_db_users/<user_id>/`` when that directory exists.

    Returns:
        JSON ``{"response": <text>}`` on success, or ``{"error": ...}`` with
        status 400 for invalid input / no loaded document and 500 when the
        models are unavailable or retrieval/generation fails.
    """
    if not llm_pipeline or not embeddings:
        return jsonify({"error": "Models not initialized. Server might be restarting or failed to load models."}), 500

    # silent=True makes a missing, malformed or non-JSON body yield None
    # instead of aborting the request. Anything that is not a JSON object is
    # treated as an empty payload so the field checks below answer with 400
    # rather than crashing on ``None.get``.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    user_id = data.get("user_id")
    query_text = data.get("query")

    if not user_id:
        return jsonify({"error": "User ID (user_id) is required to query."}), 400
    if not query_text:
        return jsonify({"error": "No query text provided."}), 400
    if not isinstance(query_text, str):
        return jsonify({"error": "'query' must be a string."}), 400

    # user_id comes from the client and is interpolated into a filesystem
    # path below, so refuse anything that is not a single path component
    # (e.g. "../../etc" or "a/b") to prevent directory traversal.
    user_dir_name = str(user_id)
    if (
        not isinstance(user_id, (str, int))
        or user_dir_name in {".", ".."}
        or os.path.basename(user_dir_name) != user_dir_name
    ):
        return jsonify({"error": "Invalid user_id: it must be a plain string or integer without path separators."}), 400

    print(f"Query received for user: {user_id}, Query: '{query_text}'")

    # Identity check (not truthiness) so a cached store is only re-opened
    # when it is genuinely absent from the cache.
    current_user_vectorstore = user_vectorstores.get(user_id)
    if current_user_vectorstore is None:
        user_persist_dir = f"./chroma_db_users/{user_id}/"
        if not os.path.exists(user_persist_dir):
            return jsonify({"error": f"No document loaded for user '{user_id}'. Please load a document first using /load_document."}), 400
        try:
            current_user_vectorstore = Chroma(persist_directory=user_persist_dir, embedding_function=embeddings)
            user_vectorstores[user_id] = current_user_vectorstore
            print(f"Loaded existing vectorstore for user '{user_id}' from disk.")
        except Exception as e:
            print(f"Error loading vectorstore from disk for user '{user_id}': {e}")
            return jsonify({"error": f"Failed to load document for user '{user_id}'. Please try loading it again or check server logs."}), 500

    try:
        retriever = current_user_vectorstore.as_retriever()

        prompt_template = ChatPromptTemplate.from_template(
            "Answer the question based ONLY on the following context. If the answer is not available in the provided context, politely state that you cannot find the answer in the provided information.\n"
            "\n"
            "Context: {context}\n"
            "\n"
            "Question: {question}\n"
        )

        # --- RAG chain for the Hugging Face pipeline ---
        # 1. Retrieve the chunks relevant to the question.
        retrieved_docs = retriever.invoke(query_text)
        context_text = "\n\n".join(doc.page_content for doc in retrieved_docs)

        # 2. Fill the template with the retrieved context and the question.
        formatted_prompt = prompt_template.format(context=context_text, question=query_text)

        # 3. Generate. return_full_text=False asks the text-generation
        #    pipeline for the completion only, so the prompt does not have to
        #    be searched for and cut out of the output afterwards.
        outputs = llm_pipeline(formatted_prompt, return_full_text=False)

        # The pipeline output is a list of dicts carrying 'generated_text'.
        generated_text = outputs[0]['generated_text']

        # Defensive fallback: if the prompt is still echoed at the start of
        # the output, drop that prefix (and only that prefix).
        if generated_text.startswith(formatted_prompt):
            generated_text = generated_text[len(formatted_prompt):]
        response = generated_text.strip()

        print(f"Response generated for user '{user_id}'.")
        return jsonify({"response": response})
    except Exception as e:
        # Request boundary: log the full traceback server-side and report
        # the failure to the client.
        print(f"ERROR: An unexpected error occurred during query for user '{user_id}': {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": f"Error processing query: {e}"}), 500
|
| 193 |
+
|
| 194 |
+
if __name__ == "__main__":
    # Load the models before serving; initialization is called directly
    # (no Flask debug mode).
    initialize_models()

    # Startup banner, one entry per printed line.
    startup_banner = (
        "Starting Flask RAG MVP application on http://0.0.0.0:7860 (Hugging Face Spaces default port)",
        f"Using LLM: {LLM_MODEL_NAME_HF}, Embeddings: {EMBEDDING_MODEL_NAME_HF}",
        "API endpoints:",
        " - POST /load_document (Requires 'user_id' and 'text')",
        " - POST /query (Requires 'user_id' and 'query')",
    )
    for banner_line in startup_banner:
        print(banner_line)

    # Hugging Face Spaces typically runs on port 7860.
    app.run(host="0.0.0.0", port=7860)
|