hmm183 commited on
Commit
b44200d
·
verified ·
1 Parent(s): 6abeebb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -254
app.py CHANGED
@@ -1,267 +1,104 @@
 
1
  import os
2
- import requests # Used for checking Ollama connection in a commented-out section, can be removed if not needed.
3
-
4
- # --- IMPORTANT: Set a writable cache directory for Hugging Face models ---
5
- # This is crucial for environments like Hugging Face Spaces where default cache locations
6
- # might not be writable or persistent. /tmp is usually writable.
7
- # HF_HOME is the preferred environment variable for Hugging Face cache.
8
- os.environ["HF_HOME"] = "/tmp/huggingface_cache"
9
- # Ensure the directory exists
10
- os.makedirs(os.environ["HF_HOME"], exist_ok=True)
11
-
12
- # --- Flask and CORS ---
13
- from flask import Flask, request, jsonify
14
- from flask_cors import CORS
15
-
16
- # --- LangChain and Hugging Face Libraries ---
17
- # Note: We are NOT using Ollama directly in this app.py for Hugging Face Spaces.
18
- # Instead, we are loading models directly via Hugging Face's transformers library.
19
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
20
- import torch # For checking GPU availability and model dtype
21
-
22
  from langchain_community.embeddings import HuggingFaceEmbeddings
23
  from langchain_community.vectorstores import Chroma
24
  from langchain_core.documents import Document
25
- from langchain_core.prompts import ChatPromptTemplate
26
- from langchain_core.output_parsers import StrOutputParser
27
- from langchain_core.runnables import RunnablePassthrough
28
  from langchain_text_splitters import RecursiveCharacterTextSplitter
29
-
30
- app = Flask(__name__)
31
- CORS(app) # Allow all origins (good for development/MVP on HF Spaces)
32
-
33
- # --- Model Configuration for Hugging Face Transformers ---
34
- # These models will be downloaded directly by the 'transformers' library.
35
- # 'google/gemma-2b-it' is chosen for its size and instruction-following capabilities.
36
- # 'sentence-transformers/all-MiniLM-L6-v2' is a small, efficient embedding model.
37
- LLM_MODEL_NAME_HF = "google/gemma-2b-it"
38
- EMBEDDING_MODEL_NAME_HF = "sentence-transformers/all-MiniLM-L6-v2"
39
-
40
- # Global variables for models
41
- llm_pipeline = None # Will hold the Hugging Face text-generation pipeline
42
- embeddings = None # Will hold the HuggingFaceEmbeddings instance
43
-
44
- # --- User-specific Vector Stores Cache ---
45
- # This dictionary will hold Chroma instances, keyed by user_id.
46
- # IMPORTANT MVP LIMITATION: This is an in-memory cache.
47
- # - If the app restarts, all loaded user contexts are lost from memory (though
48
- # Chroma data is saved to disk in `chroma_db_users`).
49
- # - For true concurrency and persistence, you'd load from disk on demand or use an external DB.
50
- user_vectorstores = {}
51
-
52
- def initialize_models():
53
- """
54
- Initialize Hugging Face models (LLM pipeline and Embeddings).
55
- This function is called once when the Flask app starts.
56
- """
57
- global llm_pipeline, embeddings
58
- print("Initializing Hugging Face models...")
59
- try:
60
- # Determine device for LLM: Use GPU if available, otherwise CPU
61
- # On Hugging Face Spaces free tier, it's usually CPU (-1).
62
- device = 0 if torch.cuda.is_available() else -1
63
- print(f"Using device for LLM: {'cuda' if device == 0 else 'cpu'}")
64
-
65
- # --- Initialize LLM Pipeline (google/gemma-2b-it) ---
66
- print(f"Loading LLM: {LLM_MODEL_NAME_HF}...")
67
- # AutoTokenizer and AutoModelForCausalLM will use HF_HOME for caching.
68
- tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME_HF)
69
- # Use bfloat16 for GPU if available to save memory, otherwise float32 for CPU.
70
- model = AutoModelForCausalLM.from_pretrained(
71
- LLM_MODEL_NAME_HF,
72
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
73
- )
74
- llm_pipeline = pipeline(
75
- "text-generation",
76
- model=model,
77
- tokenizer=tokenizer,
78
- max_new_tokens=500, # Max tokens for the generated response
79
- device=device, # Use the determined device (CPU or GPU)
80
- do_sample=True, # Enable sampling for more varied responses
81
- temperature=0.7, # Control randomness (lower for more focused, higher for more creative)
82
- top_p=0.9, # Nucleus sampling
83
- top_k=50, # Top-k sampling
84
- # Stop sequences for generation to prevent model from continuing beyond the answer
85
- # These are crucial for chat models used in RAG.
86
- eos_token_id=tokenizer.eos_token_id, # End of sequence token
87
- pad_token_id=tokenizer.pad_token_id # Pad token ID
88
- )
89
- print("LLM Pipeline initialized successfully!")
90
-
91
- # --- Initialize Hugging Face Embeddings (all-MiniLM-L6-v2) ---
92
- print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME_HF}...")
93
- embeddings = HuggingFaceEmbeddings(
94
- model_name=EMBEDDING_MODEL_NAME_HF,
95
- # Explicitly set cache_folder to ensure it uses the writable directory
96
- cache_folder=os.environ["HF_HOME"],
97
- # IMPORTANT: local_files_only=True means it will NOT try to download if not found.
98
- # If you want it to download if not present, remove this line or set to False.
99
- # For robust deployment, pre-caching and uploading the model is recommended.
100
- model_kwargs={"local_files_only": False} # Set to False to allow download if not cached
101
- )
102
- print("Embedding Model initialized successfully!")
103
-
104
- except Exception as e:
105
- print(f"ERROR: An unexpected error occurred during model initialization: {e}")
106
- llm_pipeline = None
107
- embeddings = None
108
- # Re-raise the exception to prevent the Flask app from starting if models fail to load
109
- raise e
110
-
111
- @app.route('/load_document', methods=['POST'])
112
- def load_document():
113
- """
114
- Load a document for a specific user into their dedicated persistent vector store.
115
- The text is chunked for better retrieval.
116
- """
117
- if not embeddings:
118
- return jsonify({"error": "Embedding model not initialized. Server might be restarting or failed to load models."}), 500
119
-
120
- data = request.get_json()
121
- user_id = data.get("user_id") # Expecting a user_id from the client
122
- text = data.get("text")
123
-
124
- if not user_id:
125
- return jsonify({"error": "User ID (user_id) is required to load a document."}), 400
126
- if not text:
127
- return jsonify({"error": "No text provided to load."}), 400
128
-
129
- print(f"Loading document for user: {user_id}")
130
-
131
- try:
132
- # Create a unique persistence directory for each user's ChromaDB
133
- # This will be within the Space's storage, which can be ephemeral on restarts.
134
- persist_dir = f"{os.environ['HF_HOME']}/chroma_db_users/{user_id}/"
135
- os.makedirs(persist_dir, exist_ok=True)
136
-
137
- # Wrap the input text in a LangChain Document
138
- base_document = Document(page_content=text, metadata={"user_id": user_id, "source": "user_upload"})
139
-
140
- # Chunk the document for better retrieval performance
141
- text_splitter = RecursiveCharacterTextSplitter(
142
- chunk_size=1000, # Max characters per chunk
143
- chunk_overlap=200, # Overlap between chunks to maintain context
144
- length_function=len,
145
- is_separator_regex=False,
146
- )
147
- chunks = text_splitter.split_documents([base_document])
148
-
149
- # Create/overwrite the vector store for this specific user
150
- # This will save to the user-specific directory on disk.
151
- user_vectorstores[user_id] = Chroma.from_documents(
152
- chunks, embedding=embeddings, persist_directory=persist_dir
153
- )
154
-
155
- print(f"Document loaded for user '{user_id}'. Chunks created: {len(chunks)} at {persist_dir}")
156
- return jsonify({"message": f"Document loaded successfully for user '{user_id}'.", "chunks_created": len(chunks)})
157
- except Exception as e:
158
- print(f"Error loading document for user '{user_id}': {e}")
159
- import traceback
160
- traceback.print_exc() # Print full traceback for debugging
161
- return jsonify({"error": f"Error loading document: {e}"}), 500
162
-
163
- @app.route('/query', methods=['POST'])
164
- def query():
165
- """
166
- Query the currently loaded document for a specific user to summarize or answer a question.
167
- """
168
- if not llm_pipeline or not embeddings:
169
- return jsonify({"error": "Models not initialized. Server might be restarting or failed to load models."}), 500
170
-
171
- data = request.get_json()
172
- user_id = data.get("user_id")
173
- query_text = data.get("query")
174
-
175
- if not user_id:
176
- return jsonify({"error": "User ID (user_id) is required to query."}), 400
177
- if not query_text:
178
- return jsonify({"error": "No query text provided."}), 400
179
-
180
- print(f"Query received for user: {user_id}, Query: '{query_text}'")
181
-
182
- # Retrieve the vector store for this specific user from the cache
183
- current_user_vectorstore = user_vectorstores.get(user_id)
184
-
185
- # If not in memory, attempt to load from disk for this user
186
- if not current_user_vectorstore:
187
- user_persist_dir = f"{os.environ['HF_HOME']}/chroma_db_users/{user_id}/"
188
- if os.path.exists(user_persist_dir):
189
- try:
190
- # Load the existing vectorstore from disk
191
- current_user_vectorstore = Chroma(persist_directory=user_persist_dir, embedding_function=embeddings)
192
- user_vectorstores[user_id] = current_user_vectorstore # Cache it in memory for subsequent queries
193
- print(f"Loaded existing vectorstore for user '{user_id}' from disk.")
194
- except Exception as e:
195
- print(f"Error loading vectorstore from disk for user '{user_id}': {e}")
196
- return jsonify({"error": f"Failed to load document for user '{user_id}'. Please try loading it again or check server logs."}), 500
197
  else:
198
- return jsonify({"error": f"No document loaded for user '{user_id}'. Please load a document first using /load_document."}), 400
199
 
200
- try:
201
- retriever = current_user_vectorstore.as_retriever()
202
-
203
- # Create a prompt template geared toward Q&A based on context
204
- prompt_template = ChatPromptTemplate.from_template(
205
- """Answer the question based ONLY on the following context. If the answer is not available in the provided context, politely state that you cannot find the answer in the provided information.
206
 
 
 
 
207
  Context: {context}
 
 
 
208
 
209
- Question: {question}
210
- """
211
- )
212
-
213
- # --- RAG Chain for Hugging Face Pipeline ---
214
- # Get relevant context documents
215
- retrieved_docs = retriever.invoke(query_text)
216
- context_text = "\n\n".join([doc.page_content for doc in retrieved_docs])
217
-
218
- # Format the prompt using the template and retrieved context
219
- formatted_prompt = prompt_template.format(context=context_text, question=query_text)
220
-
221
- # Use the Hugging Face pipeline directly for text generation
222
- outputs = llm_pipeline(formatted_prompt)
223
-
224
- # The output from the pipeline needs to be parsed based on its structure
225
- # It's usually a list of dictionaries, with 'generated_text' key.
226
- generated_text = outputs[0]['generated_text']
227
 
228
- # The model might repeat the prompt or parts of it, extract only the new response.
229
- # This is a common challenge with text generation.
230
- # A simple way is to find the query in the generated text and take what comes after.
231
- response_start_index = generated_text.find(formatted_prompt)
232
- if response_start_index != -1:
233
- response = generated_text[response_start_index + len(formatted_prompt):].strip()
234
- else:
235
- response = generated_text.strip() # Fallback if prompt isn't found perfectly
236
-
237
- # Further clean-up to remove any trailing prompt parts the model might generate
238
- if response.startswith("Summary:"):
239
- response = response[len("Summary:"):].strip()
240
- if response.startswith("Answer:"):
241
- response = response[len("Answer:"):].strip()
242
- if response.startswith("Question:"):
243
- response = response[len("Question:"):].strip()
244
- if response.startswith("Context:"):
245
- response = response[len("Context:"):].strip()
246
 
247
-
248
- print(f"Response generated for user '{user_id}'.")
249
- return jsonify({"response": response})
250
- except Exception as e:
251
- print(f"ERROR: An unexpected error occurred during query for user '{user_id}': {e}")
252
- import traceback
253
- traceback.print_exc()
254
- return jsonify({"error": f"Error processing query: {e}"}), 500
255
-
256
  if __name__ == "__main__":
257
- # Call initialization function directly (no Flask debug)
258
- initialize_models()
259
- print(f"Starting Flask RAG MVP application on http://0.0.0.0:7860 (Hugging Face Spaces default port)")
260
- print(f"Using LLM: {LLM_MODEL_NAME_HF}, Embeddings: {EMBEDDING_MODEL_NAME_HF}")
261
- print("API endpoints:")
262
- print(" - POST /load_document (Requires 'user_id' and 'text')")
263
- print(" - POST /query (Requires 'user_id' and 'query')")
264
-
265
- # Hugging Face Spaces typically runs on port 7860
266
- app.run(host="0.0.0.0", port=7860)
267
-
 
1
+ # fastapi_app.py
2
  import os
3
+ from fastapi import FastAPI, Request
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ import uvicorn
7
+ from typing import Dict
8
+ import torch
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from langchain_community.embeddings import HuggingFaceEmbeddings
11
  from langchain_community.vectorstores import Chroma
12
  from langchain_core.documents import Document
 
 
 
13
  from langchain_text_splitters import RecursiveCharacterTextSplitter
14
+ from langchain_core.prompts import ChatPromptTemplate
15
+ import asyncio
16
+
17
+ # Set HF cache path
18
+ os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
19
+
20
+ app = FastAPI()
21
+
22
+ app.add_middleware(
23
+ CORSMiddleware,
24
+ allow_origins=["*"],
25
+ allow_credentials=True,
26
+ allow_methods=["*"],
27
+ allow_headers=["*"],
28
+ )
29
+
30
+ # -----------------------------
31
+ # Load models on startup
32
+ # -----------------------------
33
+ LLM_MODEL_NAME = "google/flan-t5-small" # Lightweight and fast on CPU
34
+ EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
35
+
36
+ llm_model = None
37
+ llm_tokenizer = None
38
+ embeddings = None
39
+ user_vectorstores: Dict[str, Chroma] = {}
40
+
41
+ class LoadDocRequest(BaseModel):
42
+ user_id: str
43
+ text: str
44
+
45
+ class QueryRequest(BaseModel):
46
+ user_id: str
47
+ query: str
48
+
49
+ @app.on_event("startup")
50
+ async def load_models():
51
+ global llm_model, llm_tokenizer, embeddings
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+ llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
54
+ llm_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME).to(device)
55
+ embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
56
+
57
+ @app.post("/load_document")
58
+ async def load_document(data: LoadDocRequest):
59
+ user_id = data.user_id
60
+ text = data.text
61
+
62
+ persist_dir = f"./chroma_db_users/{user_id}/"
63
+ os.makedirs(persist_dir, exist_ok=True)
64
+
65
+ base_document = Document(page_content=text, metadata={"source": "upload"})
66
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
67
+ chunks = splitter.split_documents([base_document])
68
+
69
+ vectorstore = Chroma.from_documents(chunks, embedding=embeddings, persist_directory=persist_dir)
70
+ user_vectorstores[user_id] = vectorstore
71
+ return {"message": f"Loaded {len(chunks)} chunks for user {user_id}"}
72
+
73
+ @app.post("/query")
74
+ async def query(data: QueryRequest):
75
+ user_id = data.user_id
76
+ query_text = data.query
77
+
78
+ if user_id not in user_vectorstores:
79
+ persist_dir = f"./chroma_db_users/{user_id}/"
80
+ if os.path.exists(persist_dir):
81
+ user_vectorstores[user_id] = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  else:
83
+ return {"error": f"No vectorstore found for user {user_id}"}
84
 
85
+ vectorstore = user_vectorstores[user_id]
86
+ retriever = vectorstore.as_retriever()
87
+ docs = retriever.invoke(query_text)
 
 
 
88
 
89
+ context = "\n\n".join(doc.page_content for doc in docs)
90
+ prompt_template = ChatPromptTemplate.from_template(
91
+ """Answer the question based ONLY on the context below:
92
  Context: {context}
93
+ Question: {question}"""
94
+ )
95
+ prompt = prompt_template.format(context=context, question=query_text)
96
 
97
+ input_ids = llm_tokenizer(prompt, return_tensors="pt").input_ids.to(llm_model.device)
98
+ output_ids = llm_model.generate(input_ids, max_new_tokens=200)
99
+ response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ return {"response": response.replace(prompt, "").strip()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
 
 
 
 
 
 
 
 
 
103
  if __name__ == "__main__":
104
+ uvicorn.run("fastapi_app:app", host="0.0.0.0", port=7860, reload=True)