hmm183 committed on
Commit
fffc4f5
·
verified ·
1 Parent(s): 61c5719

Upload 3 files

Browse files
Files changed (3) hide show
  1. Procfile +1 -0
  2. app.py +204 -0
  3. requirements.txt +10 -0
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
+ web: python app.py
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ import os
4
+ # No requests import needed for Ollama connection check if not using Ollama
5
+
6
+ # Import Hugging Face Transformers
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
+ import torch # For checking GPU availability
9
+
10
+ from langchain_community.embeddings import HuggingFaceEmbeddings # Using HF Embeddings now
11
+ from langchain_community.vectorstores import Chroma
12
+ from langchain_core.documents import Document
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+ from langchain_core.output_parsers import StrOutputParser
15
+ from langchain_core.runnables import RunnablePassthrough
16
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
17
+
18
# Flask application object; CORS is enabled with flask_cors' default settings.
app = Flask(__name__)
CORS(app)

# --- Hugging Face model configuration ---
# Keep the LLM small: per the original author's note, larger Gemma variants
# usually do not fit on free-tier hardware. 'google/gemma-2b-it' is the
# instruction-tuned 2B checkpoint chosen as a conversational starting point.
LLM_MODEL_NAME_HF = "google/gemma-2b-it"
# Small general-purpose sentence-embedding model.
EMBEDDING_MODEL_NAME_HF = "sentence-transformers/all-MiniLM-L6-v2"

# --- Model handles (module globals) ---
# Both start as None and are assigned by initialize_models(); the request
# handlers check them before doing any work.
llm_pipeline = None  # Hugging Face text-generation pipeline
embeddings = None  # HuggingFaceEmbeddings instance

# --- Per-user vector store cache ---
# Maps user_id -> Chroma vector store for that user's loaded document.
user_vectorstores = {}
33
+
34
def initialize_models():
    """Load the LLM text-generation pipeline and the embedding model.

    On success the module globals ``llm_pipeline`` and ``embeddings`` hold the
    loaded objects. On any failure both globals are reset to ``None``, the
    error is printed, and the original exception is re-raised so the caller
    (the ``__main__`` entry point) does not start the server without models.

    Raises:
        Exception: whatever the underlying model-loading calls raised.
    """
    global llm_pipeline, embeddings
    print("Initializing Hugging Face models...")
    try:
        # Probe CUDA once; the pipeline API takes 0 for the first GPU, -1 for CPU.
        has_cuda = torch.cuda.is_available()
        device = 0 if has_cuda else -1
        device_label = 'cuda' if device == 0 else 'cpu'
        print(f"Using device: {device_label}")

        # Loading downloads the model weights on first use, so this is done
        # once at startup rather than per request.
        print(f"Loading LLM: {LLM_MODEL_NAME_HF}...")
        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME_HF)
        # bfloat16 on GPU, full float32 precision on CPU.
        weights_dtype = torch.bfloat16 if has_cuda else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_NAME_HF,
            torch_dtype=weights_dtype,
        )
        llm_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=500,  # cap on the generated response length
            device=device,
            # Further generation parameters (do_sample, top_p, temperature, ...)
            # can be added here if needed.
        )
        print("LLM Pipeline initialized successfully!")

        print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME_HF}...")
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME_HF)
        print("Embedding Model initialized successfully!")

    except Exception as e:
        print(f"ERROR: An unexpected error occurred during model initialization: {e}")
        # Never leave a half-initialized state behind (e.g. LLM loaded but
        # embeddings missing): handlers treat None as "not ready".
        llm_pipeline = None
        embeddings = None
        # Propagate so the application does not start without its models.
        raise
72
+
73
+ # --- Note on using the Hugging Face pipeline with LangChain ---
74
+ # LangChain can wrap a Hugging Face pipeline as an LLM, but that requires extra setup.
75
+ # For simplicity, no wrapper/helper is defined here: the /query handler retrieves the
76
+ # context, formats the prompt, and calls the Hugging Face pipeline directly.
77
+
78
@app.route('/load_document', methods=['POST'])
def load_document():
    """Split a user's text into chunks and index them in that user's Chroma store.

    Expects a JSON object body with:
        user_id: identifier of the user (string or integer). It is used as a
            directory name, so path separators and '.'/'..' are rejected.
        text: the document text to index (non-empty string).

    Returns:
        200 with ``message`` and ``chunks_created`` on success;
        400 for a missing/invalid body, ``user_id`` or ``text``;
        500 if the embedding model is not loaded or indexing fails.
    """
    if embeddings is None:
        return jsonify({"error": "Embedding model not initialized. Server might be restarting or failed to load models."}), 500

    # silent=True returns None instead of raising on a missing or malformed
    # JSON body, so every bad request gets a consistent JSON error response.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    user_id = data.get("user_id")
    text = data.get("text")

    if not user_id:
        return jsonify({"error": "User ID (user_id) is required to load a document."}), 400
    if not text:
        return jsonify({"error": "No text provided to load."}), 400
    if not isinstance(text, str):
        return jsonify({"error": "Field 'text' must be a string."}), 400

    # user_id comes from the client and becomes part of a filesystem path and
    # a dict key: it must be hashable and must not be able to escape the
    # per-user base directory.
    if not isinstance(user_id, (str, int)):
        return jsonify({"error": "User ID (user_id) must be a string or an integer."}), 400
    user_dir_name = str(user_id)
    if (
        user_dir_name in (".", "..")
        or any(ch in user_dir_name for ch in ("/", "\\", "\x00"))
    ):
        return jsonify({"error": "User ID (user_id) contains invalid characters."}), 400

    print(f"Loading document for user: {user_id}")

    try:
        # One persistence directory per user.
        # NOTE: on hosted platforms such as Hugging Face Spaces this local
        # storage may be ephemeral; truly persistent storage would need an
        # external store.
        persist_dir = f"./chroma_db_users/{user_id}/"
        os.makedirs(persist_dir, exist_ok=True)

        base_document = Document(
            page_content=text,
            metadata={"user_id": user_id, "source": "user_upload"},
        )
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = text_splitter.split_documents([base_document])

        # Cache the store so /query can reuse it without reopening it from disk.
        user_vectorstores[user_id] = Chroma.from_documents(
            chunks, embedding=embeddings, persist_directory=persist_dir
        )

        print(f"Document loaded for user '{user_id}'. Chunks created: {len(chunks)} at {persist_dir}")
        return jsonify({"message": f"Document loaded successfully for user '{user_id}'.", "chunks_created": len(chunks)})
    except Exception as e:
        # Top-level request boundary: report the failure instead of crashing.
        print(f"Error loading document for user '{user_id}': {e}")
        return jsonify({"error": f"Error loading document: {e}"}), 500
115
+
116
+
117
@app.route('/query', methods=['POST'])
def query():
    """Answer a question (or summarize) using the calling user's loaded document.

    Expects a JSON object body with:
        user_id: identifier of the user (string or integer). It is used as a
            directory name, so path separators and '.'/'..' are rejected.
        query: the question or instruction (non-empty string).

    The user's vector store is taken from the in-memory cache, or reopened
    from its persistence directory if it exists on disk. Retrieved chunks are
    inserted into a prompt that is passed directly to the Hugging Face
    text-generation pipeline.

    Returns:
        200 with ``response`` on success;
        400 for a missing/invalid body or field, or when no document has been
        loaded for the user;
        500 if models are not loaded, the store cannot be opened, or
        generation fails.
    """
    if llm_pipeline is None or embeddings is None:
        return jsonify({"error": "Models not initialized. Server might be restarting or failed to load models."}), 500

    # silent=True returns None instead of raising on a missing or malformed
    # JSON body, so every bad request gets a consistent JSON error response.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    user_id = data.get("user_id")
    query_text = data.get("query")

    if not user_id:
        return jsonify({"error": "User ID (user_id) is required to query."}), 400
    if not query_text:
        return jsonify({"error": "No query text provided."}), 400
    if not isinstance(query_text, str):
        return jsonify({"error": "Field 'query' must be a string."}), 400

    # user_id comes from the client and becomes part of a filesystem path and
    # a dict key: it must be hashable and must not be able to escape the
    # per-user base directory.
    if not isinstance(user_id, (str, int)):
        return jsonify({"error": "User ID (user_id) must be a string or an integer."}), 400
    user_dir_name = str(user_id)
    if (
        user_dir_name in (".", "..")
        or any(ch in user_dir_name for ch in ("/", "\\", "\x00"))
    ):
        return jsonify({"error": "User ID (user_id) contains invalid characters."}), 400

    print(f"Query received for user: {user_id}, Query: '{query_text}'")

    current_user_vectorstore = user_vectorstores.get(user_id)
    # Compare against None explicitly: a cached store must not be treated as
    # missing merely because the object evaluates as falsy.
    if current_user_vectorstore is None:
        user_persist_dir = f"./chroma_db_users/{user_id}/"
        if not os.path.exists(user_persist_dir):
            return jsonify({"error": f"No document loaded for user '{user_id}'. Please load a document first using /load_document."}), 400
        try:
            current_user_vectorstore = Chroma(persist_directory=user_persist_dir, embedding_function=embeddings)
            user_vectorstores[user_id] = current_user_vectorstore
            print(f"Loaded existing vectorstore for user '{user_id}' from disk.")
        except Exception as e:
            print(f"Error loading vectorstore from disk for user '{user_id}': {e}")
            return jsonify({"error": f"Failed to load document for user '{user_id}'. Please try loading it again or check server logs."}), 500

    try:
        retriever = current_user_vectorstore.as_retriever()

        prompt_template = ChatPromptTemplate.from_template(
            """Answer the question based ONLY on the following context. If the answer is not available in the provided context, politely state that you cannot find the answer in the provided information.

Context: {context}

Question: {question}
"""
        )

        # Retrieve the chunks relevant to the query and join them as context.
        retrieved_docs = retriever.invoke(query_text)
        context_text = "\n\n".join(doc.page_content for doc in retrieved_docs)

        formatted_prompt = prompt_template.format(context=context_text, question=query_text)

        # return_full_text=False makes the pipeline return only the newly
        # generated continuation. Stripping the prompt by string search is
        # fragile: if the echoed prompt differs even slightly, the whole
        # prompt (including the retrieved context) would leak into the reply.
        outputs = llm_pipeline(formatted_prompt, return_full_text=False)

        # The pipeline returns a list of dicts with a 'generated_text' key.
        response = outputs[0]['generated_text'].strip()

        print(f"Response generated for user '{user_id}'.")
        return jsonify({"response": response})
    except Exception as e:
        # Top-level request boundary: log the full traceback server-side and
        # report the failure to the client.
        print(f"ERROR: An unexpected error occurred during query for user '{user_id}': {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": f"Error processing query: {e}"}), 500
193
+
194
if __name__ == "__main__":
    # Load the models before serving; initialize_models() re-raises on
    # failure, so the server never starts without them (no Flask debug mode).
    initialize_models()

    # Procfile-based platforms assign the listening port through the PORT
    # environment variable; fall back to 7860, the Hugging Face Spaces default.
    port = int(os.environ.get("PORT", "7860"))

    print(f"Starting Flask RAG MVP application on http://0.0.0.0:{port} (Hugging Face Spaces default port: 7860)")
    print(f"Using LLM: {LLM_MODEL_NAME_HF}, Embeddings: {EMBEDDING_MODEL_NAME_HF}")
    print("API endpoints:")
    print("  - POST /load_document (Requires 'user_id' and 'text')")
    print("  - POST /query (Requires 'user_id' and 'query')")

    # Bind to all interfaces so the hosting platform's proxy can reach the app.
    app.run(host="0.0.0.0", port=port)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Flask
2
+ Flask-Cors
3
+ transformers
4
+ torch # Needed for transformers, PyTorch will be installed
5
+ accelerate # Often needed for optimizing transformer models
6
+ langchain-community
7
+ langchain-chroma
8
+ langchain-core
9
+ langchain-text-splitters
10
+ sentence-transformers # For HuggingFaceEmbeddings