hmm183 committed on
Commit
f4b962e
·
verified ·
1 Parent(s): 71469fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -48
app.py CHANGED
@@ -1,16 +1,25 @@
1
- import os # Import os at the top
2
- # Set a writable cache directory for transformers
3
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
4
 
 
 
 
 
 
 
 
 
 
5
  from flask import Flask, request, jsonify
6
  from flask_cors import CORS
7
- # No requests import needed for Ollama connection check if not using Ollama
8
 
9
- # Import Hugging Face Transformers
 
 
10
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
11
- import torch
12
 
13
- from langchain_community.embeddings import HuggingFaceEmbeddings
14
  from langchain_community.vectorstores import Chroma
15
  from langchain_core.documents import Document
16
  from langchain_core.prompts import ChatPromptTemplate
@@ -19,94 +28,126 @@ from langchain_core.runnables import RunnablePassthrough
19
  from langchain_text_splitters import RecursiveCharacterTextSplitter
20
 
21
  app = Flask(__name__)
22
- CORS(app)
23
 
24
- # ... (rest of your app.py code) ...
25
  # --- Model Configuration for Hugging Face Transformers ---
26
- # CHOOSE A SMALLER MODEL! Gemma 4B is too large for free tier usually.
27
- # 'google/gemma-2b-it' is a good conversational starting point.
 
28
  LLM_MODEL_NAME_HF = "google/gemma-2b-it"
29
- EMBEDDING_MODEL_NAME_HF = "sentence-transformers/all-MiniLM-L6-v2" # Standard small embedding model
30
 
31
  # Global variables for models
32
- llm_pipeline = None # Will be a Hugging Face pipeline
33
- embeddings = None # Will be a HuggingFaceEmbeddings instance
34
 
35
  # --- User-specific Vector Stores Cache ---
 
 
 
 
 
36
  user_vectorstores = {}
37
 
38
  def initialize_models():
39
  """
40
  Initialize Hugging Face models (LLM pipeline and Embeddings).
 
41
  """
42
  global llm_pipeline, embeddings
43
  print("Initializing Hugging Face models...")
44
  try:
45
  # Determine device for LLM: Use GPU if available, otherwise CPU
 
46
  device = 0 if torch.cuda.is_available() else -1
47
- print(f"Using device: {'cuda' if device == 0 else 'cpu'}")
48
 
49
- # Initialize LLM Pipeline
50
- # This will download the model weights (gemma-2b-it is ~5GB)
51
- # It's recommended to do this once at startup.
52
  print(f"Loading LLM: {LLM_MODEL_NAME_HF}...")
 
53
  tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME_HF)
54
- model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME_HF, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32)
 
 
 
 
55
  llm_pipeline = pipeline(
56
  "text-generation",
57
  model=model,
58
  tokenizer=tokenizer,
59
- max_new_tokens=500, # Limit response length
60
- device=device,
61
- # Add other generation parameters as needed, e.g., do_sample=True, top_p=0.9, temperature=0.7
 
 
 
 
 
 
 
62
  )
63
  print("LLM Pipeline initialized successfully!")
64
 
65
- # Initialize Hugging Face Embeddings
66
  print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME_HF}...")
67
- embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME_HF)
 
 
 
 
 
 
 
 
68
  print("Embedding Model initialized successfully!")
69
 
70
  except Exception as e:
71
  print(f"ERROR: An unexpected error occurred during model initialization: {e}")
72
  llm_pipeline = None
73
  embeddings = None
74
- # Raise the exception to prevent the app from starting if models fail to load
75
  raise e
76
 
77
- # --- Helper function to adapt HF pipeline to LangChain's LLM interface ---
78
- # LangChain's pipeline.py can convert HF pipelines but requires some setup.
79
- # For simplicity, we'll manually wrap it in the RAG chain
80
- # We will use it directly in the RAG chain's invoke step.
81
-
82
  @app.route('/load_document', methods=['POST'])
83
  def load_document():
84
- # ... (rest of your /load_document function remains largely the same) ...
85
- # Ensure 'embeddings' is properly loaded before this.
 
 
86
  if not embeddings:
87
  return jsonify({"error": "Embedding model not initialized. Server might be restarting or failed to load models."}), 500
88
 
89
  data = request.get_json()
90
- user_id = data.get("user_id")
91
  text = data.get("text")
92
 
93
- if not user_id: return jsonify({"error": "User ID (user_id) is required to load a document."}), 400
94
- if not text: return jsonify({"error": "No text provided to load."}), 400
 
 
95
 
96
  print(f"Loading document for user: {user_id}")
97
 
98
  try:
99
  # Create a unique persistence directory for each user's ChromaDB
100
- # NOTE: On Hugging Face Spaces, this persist_dir will be within the Space's storage,
101
- # which can be ephemeral or reset, depending on space type/resource usage.
102
- # For a true persistent solution, you'd need external storage.
103
- persist_dir = f"./chroma_db_users/{user_id}/"
104
  os.makedirs(persist_dir, exist_ok=True)
105
 
 
106
  base_document = Document(page_content=text, metadata={"user_id": user_id, "source": "user_upload"})
107
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
 
 
 
 
 
 
 
108
  chunks = text_splitter.split_documents([base_document])
109
 
 
 
110
  user_vectorstores[user_id] = Chroma.from_documents(
111
  chunks, embedding=embeddings, persist_directory=persist_dir
112
  )
@@ -115,9 +156,10 @@ def load_document():
115
  return jsonify({"message": f"Document loaded successfully for user '{user_id}'.", "chunks_created": len(chunks)})
116
  except Exception as e:
117
  print(f"Error loading document for user '{user_id}': {e}")
 
 
118
  return jsonify({"error": f"Error loading document: {e}"}), 500
119
 
120
-
121
  @app.route('/query', methods=['POST'])
122
  def query():
123
  """
@@ -130,18 +172,24 @@ def query():
130
  user_id = data.get("user_id")
131
  query_text = data.get("query")
132
 
133
- if not user_id: return jsonify({"error": "User ID (user_id) is required to query."}), 400
134
- if not query_text: return jsonify({"error": "No query text provided."}), 400
 
 
135
 
136
  print(f"Query received for user: {user_id}, Query: '{query_text}'")
137
 
 
138
  current_user_vectorstore = user_vectorstores.get(user_id)
 
 
139
  if not current_user_vectorstore:
140
- user_persist_dir = f"./chroma_db_users/{user_id}/"
141
  if os.path.exists(user_persist_dir):
142
  try:
 
143
  current_user_vectorstore = Chroma(persist_directory=user_persist_dir, embedding_function=embeddings)
144
- user_vectorstores[user_id] = current_user_vectorstore
145
  print(f"Loaded existing vectorstore for user '{user_id}' from disk.")
146
  except Exception as e:
147
  print(f"Error loading vectorstore from disk for user '{user_id}': {e}")
@@ -152,6 +200,7 @@ def query():
152
  try:
153
  retriever = current_user_vectorstore.as_retriever()
154
 
 
155
  prompt_template = ChatPromptTemplate.from_template(
156
  """Answer the question based ONLY on the following context. If the answer is not available in the provided context, politely state that you cannot find the answer in the provided information.
157
 
@@ -170,12 +219,10 @@ Question: {question}
170
  formatted_prompt = prompt_template.format(context=context_text, question=query_text)
171
 
172
  # Use the Hugging Face pipeline directly for text generation
173
- # Pass the formatted prompt to the pipeline
174
  outputs = llm_pipeline(formatted_prompt)
175
 
176
  # The output from the pipeline needs to be parsed based on its structure
177
  # It's usually a list of dictionaries, with 'generated_text' key.
178
- # You might need to refine this parsing based on the exact model's output format.
179
  generated_text = outputs[0]['generated_text']
180
 
181
  # The model might repeat the prompt or parts of it, extract only the new response.
@@ -187,6 +234,17 @@ Question: {question}
187
  else:
188
  response = generated_text.strip() # Fallback if prompt isn't found perfectly
189
 
 
 
 
 
 
 
 
 
 
 
 
190
  print(f"Response generated for user '{user_id}'.")
191
  return jsonify({"response": response})
192
  except Exception as e:
@@ -205,4 +263,5 @@ if __name__ == "__main__":
205
  print(" - POST /query (Requires 'user_id' and 'query')")
206
 
207
  # Hugging Face Spaces typically runs on port 7860
208
- app.run(host="0.0.0.0", port=7860)
 
 
1
+ import os
2
+ import requests # Used for checking Ollama connection in a commented-out section, can be removed if not needed.
 
3
 
4
+ # --- IMPORTANT: Set a writable cache directory for Hugging Face models ---
5
+ # This is crucial for environments like Hugging Face Spaces where default cache locations
6
+ # might not be writable or persistent. /tmp is usually writable.
7
+ # HF_HOME is the preferred environment variable for Hugging Face cache.
8
+ os.environ["HF_HOME"] = "/tmp/huggingface_cache"
9
+ # Ensure the directory exists
10
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
11
+
12
+ # --- Flask and CORS ---
13
  from flask import Flask, request, jsonify
14
  from flask_cors import CORS
 
15
 
16
+ # --- LangChain and Hugging Face Libraries ---
17
+ # Note: We are NOT using Ollama directly in this app.py for Hugging Face Spaces.
18
+ # Instead, we are loading models directly via Hugging Face's transformers library.
19
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
20
+ import torch # For checking GPU availability and model dtype
21
 
22
+ from langchain_community.embeddings import HuggingFaceEmbeddings
23
  from langchain_community.vectorstores import Chroma
24
  from langchain_core.documents import Document
25
  from langchain_core.prompts import ChatPromptTemplate
 
28
  from langchain_text_splitters import RecursiveCharacterTextSplitter
29
 
30
  app = Flask(__name__)
31
+ CORS(app) # Allow all origins (good for development/MVP on HF Spaces)
32
 
 
33
  # --- Model Configuration for Hugging Face Transformers ---
34
+ # These models will be downloaded directly by the 'transformers' library.
35
+ # 'google/gemma-2b-it' is chosen for its size and instruction-following capabilities.
36
+ # 'sentence-transformers/all-MiniLM-L6-v2' is a small, efficient embedding model.
37
  LLM_MODEL_NAME_HF = "google/gemma-2b-it"
38
+ EMBEDDING_MODEL_NAME_HF = "sentence-transformers/all-MiniLM-L6-v2"
39
 
40
  # Global variables for models
41
+ llm_pipeline = None # Will hold the Hugging Face text-generation pipeline
42
+ embeddings = None # Will hold the HuggingFaceEmbeddings instance
43
 
44
  # --- User-specific Vector Stores Cache ---
45
+ # This dictionary will hold Chroma instances, keyed by user_id.
46
+ # IMPORTANT MVP LIMITATION: This is an in-memory cache.
47
+ # - If the app restarts, all loaded user contexts are lost from memory (though
48
+ # Chroma data is saved to disk under `chroma_db_users/<user_id>/` and /query reloads it on demand).
49
+ # - For true concurrency and durable persistence, use an external DB.
50
  user_vectorstores = {}
51
 
52
  def initialize_models():
53
  """
54
  Initialize Hugging Face models (LLM pipeline and Embeddings).
55
+ This function is called once when the Flask app starts.
56
  """
57
  global llm_pipeline, embeddings
58
  print("Initializing Hugging Face models...")
59
  try:
60
  # Determine device for LLM: Use GPU if available, otherwise CPU
61
+ # On Hugging Face Spaces free tier, it's usually CPU (-1).
62
  device = 0 if torch.cuda.is_available() else -1
63
+ print(f"Using device for LLM: {'cuda' if device == 0 else 'cpu'}")
64
 
65
+ # --- Initialize LLM Pipeline (google/gemma-2b-it) ---
 
 
66
  print(f"Loading LLM: {LLM_MODEL_NAME_HF}...")
67
+ # AutoTokenizer and AutoModelForCausalLM will use HF_HOME for caching.
68
  tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME_HF)
69
+ # Use bfloat16 for GPU if available to save memory, otherwise float32 for CPU.
70
+ model = AutoModelForCausalLM.from_pretrained(
71
+ LLM_MODEL_NAME_HF,
72
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
73
+ )
74
  llm_pipeline = pipeline(
75
  "text-generation",
76
  model=model,
77
  tokenizer=tokenizer,
78
+ max_new_tokens=500, # Max tokens for the generated response
79
+ device=device, # Use the determined device (CPU or GPU)
80
+ do_sample=True, # Enable sampling for more varied responses
81
+ temperature=0.7, # Control randomness (lower for more focused, higher for more creative)
82
+ top_p=0.9, # Nucleus sampling
83
+ top_k=50, # Top-k sampling
84
+ # Stop sequences for generation to prevent model from continuing beyond the answer
85
+ # These are crucial for chat models used in RAG.
86
+ eos_token_id=tokenizer.eos_token_id, # End of sequence token
87
+ pad_token_id=tokenizer.pad_token_id # Pad token ID
88
  )
89
  print("LLM Pipeline initialized successfully!")
90
 
91
+ # --- Initialize Hugging Face Embeddings (all-MiniLM-L6-v2) ---
92
  print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME_HF}...")
93
+ embeddings = HuggingFaceEmbeddings(
94
+ model_name=EMBEDDING_MODEL_NAME_HF,
95
+ # Explicitly set cache_folder to ensure it uses the writable directory
96
+ cache_folder=os.environ["HF_HOME"],
97
+ # NOTE: local_files_only=True would mean it will NOT try to download if not found.
98
+ # It is set to False below so the model is downloaded when it is not already cached.
99
+ # For robust deployment, pre-caching and uploading the model is recommended.
100
+ model_kwargs={"local_files_only": False} # Set to False to allow download if not cached
101
+ )
102
  print("Embedding Model initialized successfully!")
103
 
104
  except Exception as e:
105
  print(f"ERROR: An unexpected error occurred during model initialization: {e}")
106
  llm_pipeline = None
107
  embeddings = None
108
+ # Re-raise the exception to prevent the Flask app from starting if models fail to load
109
  raise e
110
 
 
 
 
 
 
111
  @app.route('/load_document', methods=['POST'])
112
  def load_document():
113
+ """
114
+ Load a document for a specific user into their dedicated persistent vector store.
115
+ The text is chunked for better retrieval.
116
+ """
117
  if not embeddings:
118
  return jsonify({"error": "Embedding model not initialized. Server might be restarting or failed to load models."}), 500
119
 
120
  data = request.get_json()
121
+ user_id = data.get("user_id") # Expecting a user_id from the client
122
  text = data.get("text")
123
 
124
+ if not user_id:
125
+ return jsonify({"error": "User ID (user_id) is required to load a document."}), 400
126
+ if not text:
127
+ return jsonify({"error": "No text provided to load."}), 400
128
 
129
  print(f"Loading document for user: {user_id}")
130
 
131
  try:
132
  # Create a unique persistence directory for each user's ChromaDB
133
+ # This will be within the Space's storage, which can be ephemeral on restarts.
134
+ persist_dir = f"{os.environ['HF_HOME']}/chroma_db_users/{user_id}/"
 
 
135
  os.makedirs(persist_dir, exist_ok=True)
136
 
137
+ # Wrap the input text in a LangChain Document
138
  base_document = Document(page_content=text, metadata={"user_id": user_id, "source": "user_upload"})
139
+
140
+ # Chunk the document for better retrieval performance
141
+ text_splitter = RecursiveCharacterTextSplitter(
142
+ chunk_size=1000, # Max characters per chunk
143
+ chunk_overlap=200, # Overlap between chunks to maintain context
144
+ length_function=len,
145
+ is_separator_regex=False,
146
+ )
147
  chunks = text_splitter.split_documents([base_document])
148
 
149
+ # Create/overwrite the vector store for this specific user
150
+ # This will save to the user-specific directory on disk.
151
  user_vectorstores[user_id] = Chroma.from_documents(
152
  chunks, embedding=embeddings, persist_directory=persist_dir
153
  )
 
156
  return jsonify({"message": f"Document loaded successfully for user '{user_id}'.", "chunks_created": len(chunks)})
157
  except Exception as e:
158
  print(f"Error loading document for user '{user_id}': {e}")
159
+ import traceback
160
+ traceback.print_exc() # Print full traceback for debugging
161
  return jsonify({"error": f"Error loading document: {e}"}), 500
162
 
 
163
  @app.route('/query', methods=['POST'])
164
  def query():
165
  """
 
172
  user_id = data.get("user_id")
173
  query_text = data.get("query")
174
 
175
+ if not user_id:
176
+ return jsonify({"error": "User ID (user_id) is required to query."}), 400
177
+ if not query_text:
178
+ return jsonify({"error": "No query text provided."}), 400
179
 
180
  print(f"Query received for user: {user_id}, Query: '{query_text}'")
181
 
182
+ # Retrieve the vector store for this specific user from the cache
183
  current_user_vectorstore = user_vectorstores.get(user_id)
184
+
185
+ # If not in memory, attempt to load from disk for this user
186
  if not current_user_vectorstore:
187
+ user_persist_dir = f"{os.environ['HF_HOME']}/chroma_db_users/{user_id}/"
188
  if os.path.exists(user_persist_dir):
189
  try:
190
+ # Load the existing vectorstore from disk
191
  current_user_vectorstore = Chroma(persist_directory=user_persist_dir, embedding_function=embeddings)
192
+ user_vectorstores[user_id] = current_user_vectorstore # Cache it in memory for subsequent queries
193
  print(f"Loaded existing vectorstore for user '{user_id}' from disk.")
194
  except Exception as e:
195
  print(f"Error loading vectorstore from disk for user '{user_id}': {e}")
 
200
  try:
201
  retriever = current_user_vectorstore.as_retriever()
202
 
203
+ # Create a prompt template geared toward Q&A based on context
204
  prompt_template = ChatPromptTemplate.from_template(
205
  """Answer the question based ONLY on the following context. If the answer is not available in the provided context, politely state that you cannot find the answer in the provided information.
206
 
 
219
  formatted_prompt = prompt_template.format(context=context_text, question=query_text)
220
 
221
  # Use the Hugging Face pipeline directly for text generation
 
222
  outputs = llm_pipeline(formatted_prompt)
223
 
224
  # The output from the pipeline needs to be parsed based on its structure
225
  # It's usually a list of dictionaries, with 'generated_text' key.
 
226
  generated_text = outputs[0]['generated_text']
227
 
228
  # The model might repeat the prompt or parts of it, extract only the new response.
 
234
  else:
235
  response = generated_text.strip() # Fallback if prompt isn't found perfectly
236
 
237
+ # Further clean-up to strip leading prompt labels (e.g. "Answer:") the model might generate
238
+ if response.startswith("Summary:"):
239
+ response = response[len("Summary:"):].strip()
240
+ if response.startswith("Answer:"):
241
+ response = response[len("Answer:"):].strip()
242
+ if response.startswith("Question:"):
243
+ response = response[len("Question:"):].strip()
244
+ if response.startswith("Context:"):
245
+ response = response[len("Context:"):].strip()
246
+
247
+
248
  print(f"Response generated for user '{user_id}'.")
249
  return jsonify({"response": response})
250
  except Exception as e:
 
263
  print(" - POST /query (Requires 'user_id' and 'query')")
264
 
265
  # Hugging Face Spaces typically runs on port 7860
266
+ app.run(host="0.0.0.0", port=7860)
267
+