huynguyen6906 committed on
Commit
d43ecac
·
verified ·
1 Parent(s): fff5f84

Update server_paper_RAM_optimize.py

Browse files
Files changed (1) hide show
  1. server_paper_RAM_optimize.py +29 -152
server_paper_RAM_optimize.py CHANGED
@@ -7,41 +7,44 @@ from flask_cors import CORS
7
  from sentence_transformers import SentenceTransformer
8
  import PyPDF2
9
  import io
 
10
 
11
  app = Flask(__name__)
12
  CORS(app, origins=['*'])
13
- H5_FILE_PATH='Papers_Embedbed_0-1000000.h5'
14
- BIN_FILE_PATH='hnsw_paper_index.bin'
15
- os.environ['H5_FILE_PATH'] = H5_FILE_PATH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  class PaperSearchEngine:
18
  def __init__(self, h5_file_path=H5_FILE_PATH):
19
- """
20
- Initialize the Paper Search Engine with Sentence Transformers and HNSW index.
21
-
22
- Args:
23
- h5_file_path: Path to the HDF5 file containing paper embeddings and URLs
24
- """
25
  print("Initializing Paper Search Engine...")
26
 
27
- # Load Sentence Transformer model (same model used for embeddings)
28
  print("Loading Sentence Transformer model (all-roberta-large-v1)...")
29
  self.model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
30
  print("Model loaded successfully!")
31
 
32
- # Check if .h5 file exists (required for metadata and URLs)
33
  if not os.path.exists(h5_file_path):
34
  print(f"❌ Error: {h5_file_path} not found!")
35
- print(" Please ensure the h5 file is in the backend directory")
36
  raise FileNotFoundError(f"Required file not found: {h5_file_path}")
37
 
38
- # Check if .bin file exists for faster index loading
39
- bin_exists = os.path.exists(BIN_FILE_PATH)
40
- if bin_exists:
41
- print(f"⚡ Found existing HNSW index: {BIN_FILE_PATH}")
42
- else:
43
- print(f"📂 .bin file not found, will build index from embeddings")
44
-
45
  # Load embeddings and URLs from HDF5
46
  print(f"Loading embeddings from {h5_file_path}...")
47
  self.paper = h5py.File(h5_file_path, 'r')
@@ -54,72 +57,34 @@ class PaperSearchEngine:
54
 
55
  # Check if .bin file exists for faster loading
56
  if os.path.exists(BIN_FILE_PATH):
57
- print(f"Loading HNSW index from .bin file (fast mode)...")
58
  self.index = hnswlib.Index(space='cosine', dim=dim)
59
  self.index.load_index(BIN_FILE_PATH, max_elements=max_elements)
60
  self.index.set_ef(200)
61
- print("✅ HNSW index loaded successfully from .bin file!")
62
  else:
63
- # Build HNSW index from scratch if .bin doesn't exist
64
  print("Building HNSW index from scratch...")
65
- print("(This may take a while for the first run)")
66
  self.index = hnswlib.Index(space='cosine', dim=dim)
67
-
68
- # Initialize index with capacity
69
- self.index.init_index(
70
- max_elements=max_elements,
71
- ef_construction=400, # Higher = better quality, slower build
72
- M=200 # Number of connections per element
73
- )
74
-
75
- # Add embeddings to index
76
- self.index.add_items(self.paper["embeddings"], np.arange(len(self.paper["embeddings"])))
77
-
78
- # Set ef for search (higher = more accurate, slower)
79
  self.index.set_ef(200)
80
-
81
- # Save index for future runs
82
  self.index.save_index(BIN_FILE_PATH)
83
  print(f"💾 Saved HNSW index to: {BIN_FILE_PATH}")
84
- print(" (Next startup will be faster!)")
85
 
86
- print("HNSW index built successfully!")
87
  print("Paper Search Engine ready!")
88
 
89
  def text_to_vector(self, text):
90
- """
91
- Convert text to embedding vector using Sentence Transformer.
92
-
93
- Args:
94
- text: Input text (query, abstract, etc.)
95
-
96
- Returns:
97
- numpy array: L2-normalized embedding vector
98
- """
99
- # Encode text and normalize
100
  embedding = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)
101
  return embedding[0]
102
 
103
  def extract_text_from_file(self, file_content, file_extension):
104
- """
105
- Extract text from uploaded file.
106
-
107
- Args:
108
- file_content: File content as bytes
109
- file_extension: File extension (.txt, .pdf, .md)
110
-
111
- Returns:
112
- str: Extracted text
113
- """
114
  if file_extension in ['.txt', '.md']:
115
- # Plain text files
116
  try:
117
  return file_content.decode('utf-8')
118
  except UnicodeDecodeError:
119
  return file_content.decode('latin-1')
120
-
121
  elif file_extension == '.pdf':
122
- # PDF files
123
  try:
124
  pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
125
  text = ""
@@ -128,68 +93,29 @@ class PaperSearchEngine:
128
  return text.strip()
129
  except Exception as e:
130
  raise ValueError(f"Error extracting text from PDF: {str(e)}")
131
-
132
  else:
133
  raise ValueError(f"Unsupported file type: {file_extension}")
134
 
135
  def search(self, query_text, k=10):
136
- """
137
- Search for similar papers using text query.
138
-
139
- Args:
140
- query_text: Text query (keywords, sentence, abstract, etc.)
141
- k: Number of results to return
142
-
143
- Returns:
144
- list: List of (url, similarity_score) tuples
145
- """
146
- # Convert query to vector
147
  query_vector = self.text_to_vector(query_text)
148
-
149
- # Search using HNSW
150
  labels, distances = self.index.knn_query(query_vector, k=k)
151
-
152
- # Convert cosine distance to similarity (1 - distance)
153
  similarities = 1 - distances[0]
154
-
155
- # Get results
156
  results = []
157
  for idx, similarity in zip(labels[0], similarities):
158
  results.append({
159
  'url': self.paper["urls"][idx].decode('utf-8'),
160
  'similarity': float(similarity)
161
  })
162
-
163
  return results
164
 
165
  def search_by_file(self, file_content, file_extension, k=10):
166
- """
167
- Search for similar papers using uploaded file content.
168
-
169
- Args:
170
- file_content: File content as bytes
171
- file_extension: File extension
172
- k: Number of results to return
173
-
174
- Returns:
175
- list: List of (url, similarity_score) tuples
176
- """
177
- # Extract text from file
178
  text = self.extract_text_from_file(file_content, file_extension)
179
-
180
- # Use regular text search
181
  return self.search(text, k)
182
 
183
-
184
- # Initialize search engine (singleton)
185
- print("Starting Flask server...")
186
- H5_FILE_PATH = os.getenv('H5_FILE_PATH', 'Papers_Embedbed_0-1000000.h5')
187
  search_engine = PaperSearchEngine(h5_file_path=H5_FILE_PATH)
188
 
189
-
190
  @app.route('/health', methods=['GET'])
191
  def health_check():
192
- """Health check endpoint"""
193
  return jsonify({
194
  'status': 'healthy',
195
  'service': 'paper-search-engine',
@@ -198,100 +124,51 @@ def health_check():
198
  'model': 'all-roberta-large-v1'
199
  })
200
 
201
-
202
  @app.route('/search', methods=['POST'])
203
  def search_text():
204
- """
205
- Text-based paper search endpoint.
206
-
207
- Expects JSON:
208
- {
209
- "query": "machine learning transformers",
210
- "k": 10
211
- }
212
- """
213
  try:
214
  data = request.get_json()
215
-
216
  if not data or 'query' not in data:
217
  return jsonify({'error': 'Missing query parameter'}), 400
218
-
219
  query = data['query']
220
  k = data.get('k', 10)
221
-
222
- # Validate k
223
  if not isinstance(k, int) or k < 1 or k > 100:
224
  return jsonify({'error': 'k must be an integer between 1 and 100'}), 400
225
-
226
- # Perform search
227
  results = search_engine.search(query, k=k)
228
-
229
  return jsonify({
230
  'query': query,
231
  'k': k,
232
  'results': results
233
  })
234
-
235
  except Exception as e:
236
  return jsonify({'error': str(e)}), 500
237
 
238
-
239
  @app.route('/search/file', methods=['POST'])
240
  def search_file():
241
- """
242
- File-based paper search endpoint.
243
-
244
- Expects multipart/form-data with:
245
- - file: The uploaded file (.txt, .pdf, .md)
246
- - k: Number of results (optional, default 10)
247
- """
248
  try:
249
- # Check if file is present
250
  if 'file' not in request.files:
251
  return jsonify({'error': 'No file provided'}), 400
252
-
253
  file = request.files['file']
254
-
255
  if file.filename == '':
256
  return jsonify({'error': 'Empty filename'}), 400
257
-
258
- # Get file extension
259
  file_extension = os.path.splitext(file.filename)[1].lower()
260
-
261
  if file_extension not in ['.txt', '.pdf', '.md']:
262
  return jsonify({'error': 'Unsupported file type. Supported: .txt, .pdf, .md'}), 400
263
-
264
- # Read file content
265
  file_content = file.read()
266
-
267
- # Get k parameter
268
  k = request.form.get('k', 10, type=int)
269
-
270
- # Validate k
271
  if k < 1 or k > 100:
272
  return jsonify({'error': 'k must be between 1 and 100'}), 400
273
-
274
- # Perform search
275
  results = search_engine.search_by_file(file_content, file_extension, k=k)
276
-
277
  return jsonify({
278
  'filename': file.filename,
279
  'k': k,
280
  'results': results
281
  })
282
-
283
  except ValueError as e:
284
  return jsonify({'error': str(e)}), 400
285
  except Exception as e:
286
  return jsonify({'error': str(e)}), 500
287
 
288
-
289
  if __name__ == '__main__':
290
- port = int(os.getenv('PORT', 5001)) # Default to 5001 to avoid conflict with image server
291
- debug = os.getenv('FLASK_DEBUG', '0') == '1'
292
-
293
- print(f"\nPaper Search Engine running on http://localhost:{port}")
294
- print(f"Health check: http://localhost:{port}/health")
295
- print(f"Total papers indexed: {len(search_engine.paper['urls'])}")
296
-
297
- app.run(host='0.0.0.0', port=port, debug=debug)
 
7
  from sentence_transformers import SentenceTransformer
8
  import PyPDF2
9
  import io
10
+ from huggingface_hub import hf_hub_download
11
 
12
  app = Flask(__name__)
13
  CORS(app, origins=['*'])
14
+
15
print("\n" + "="*50)
print("📥 INITIALIZING PAPER SERVER...")
print("="*50)

# Dataset configuration
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_ID = "huynguyen6906/Image_server_data"  # Replace with your own paper dataset ID if it differs

# Local fallbacks, used for whichever file cannot be fetched from the Hub.
H5_FILE_PATH = 'Papers_Embedbed_0-1000000.h5'
BIN_FILE_PATH = 'hnsw_paper_index.bin'

# Download the files from the Hugging Face dataset.
# Each file gets its own try block: a missing or failing index download must
# not throw away an embeddings file that was already downloaded successfully
# (the engine can rebuild the index, but it cannot start without the .h5 file).
print(f"Downloading data from {DATASET_ID}...")
try:
    H5_FILE_PATH = hf_hub_download(
        repo_id=DATASET_ID,
        filename="Papers_Embedbed_0-1000000.h5",
        repo_type="dataset",
        token=HF_TOKEN,
    )
    print(f"✅ Data loaded: {H5_FILE_PATH}")
except Exception as e:
    # Best effort: keep the local fallback path and let the engine report
    # a missing file if it is not there either.
    print(f"❌ Error downloading data: {str(e)}")

try:
    BIN_FILE_PATH = hf_hub_download(
        repo_id=DATASET_ID,
        filename="hnsw_paper_index.bin",
        repo_type="dataset",
        token=HF_TOKEN,
    )
    print(f"✅ Index loaded: {BIN_FILE_PATH}")
except Exception as e:
    # Best effort: keep the local fallback path for the index.
    print(f"❌ Error downloading index: {str(e)}")
33
 
34
  class PaperSearchEngine:
35
  def __init__(self, h5_file_path=H5_FILE_PATH):
 
 
 
 
 
 
36
  print("Initializing Paper Search Engine...")
37
 
38
+ # Load Sentence Transformer model
39
  print("Loading Sentence Transformer model (all-roberta-large-v1)...")
40
  self.model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
41
  print("Model loaded successfully!")
42
 
43
+ # Check if .h5 file exists
44
  if not os.path.exists(h5_file_path):
45
  print(f"❌ Error: {h5_file_path} not found!")
 
46
  raise FileNotFoundError(f"Required file not found: {h5_file_path}")
47
 
 
 
 
 
 
 
 
48
  # Load embeddings and URLs from HDF5
49
  print(f"Loading embeddings from {h5_file_path}...")
50
  self.paper = h5py.File(h5_file_path, 'r')
 
57
 
58
  # Check if .bin file exists for faster loading
59
  if os.path.exists(BIN_FILE_PATH):
60
+ print(f"Loading HNSW index from {BIN_FILE_PATH}...")
61
  self.index = hnswlib.Index(space='cosine', dim=dim)
62
  self.index.load_index(BIN_FILE_PATH, max_elements=max_elements)
63
  self.index.set_ef(200)
64
+ print("✅ HNSW index loaded!")
65
  else:
66
+ # Build HNSW index from scratch
67
  print("Building HNSW index from scratch...")
 
68
  self.index = hnswlib.Index(space='cosine', dim=dim)
69
+ self.index.init_index(max_elements=max_elements, ef_construction=400, M=200)
70
+ self.index.add_items(self.paper["embeddings"])
 
 
 
 
 
 
 
 
 
 
71
  self.index.set_ef(200)
 
 
72
  self.index.save_index(BIN_FILE_PATH)
73
  print(f"💾 Saved HNSW index to: {BIN_FILE_PATH}")
 
74
 
 
75
  print("Paper Search Engine ready!")
76
 
77
  def text_to_vector(self, text):
 
 
 
 
 
 
 
 
 
 
78
  embedding = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)
79
  return embedding[0]
80
 
81
  def extract_text_from_file(self, file_content, file_extension):
 
 
 
 
 
 
 
 
 
 
82
  if file_extension in ['.txt', '.md']:
 
83
  try:
84
  return file_content.decode('utf-8')
85
  except UnicodeDecodeError:
86
  return file_content.decode('latin-1')
 
87
  elif file_extension == '.pdf':
 
88
  try:
89
  pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
90
  text = ""
 
93
  return text.strip()
94
  except Exception as e:
95
  raise ValueError(f"Error extracting text from PDF: {str(e)}")
 
96
  else:
97
  raise ValueError(f"Unsupported file type: {file_extension}")
98
 
99
  def search(self, query_text, k=10):
 
 
 
 
 
 
 
 
 
 
 
100
  query_vector = self.text_to_vector(query_text)
 
 
101
  labels, distances = self.index.knn_query(query_vector, k=k)
 
 
102
  similarities = 1 - distances[0]
 
 
103
  results = []
104
  for idx, similarity in zip(labels[0], similarities):
105
  results.append({
106
  'url': self.paper["urls"][idx].decode('utf-8'),
107
  'similarity': float(similarity)
108
  })
 
109
  return results
110
 
111
  def search_by_file(self, file_content, file_extension, k=10):
 
 
 
 
 
 
 
 
 
 
 
 
112
  text = self.extract_text_from_file(file_content, file_extension)
 
 
113
  return self.search(text, k)
114
 
 
 
 
 
115
  # Single module-level engine, constructed once at import time; the /search
  # and /search/file handlers call its search methods.
  search_engine = PaperSearchEngine(h5_file_path=H5_FILE_PATH)
116
 
 
117
  @app.route('/health', methods=['GET'])
118
  def health_check():
 
119
  return jsonify({
120
  'status': 'healthy',
121
  'service': 'paper-search-engine',
 
124
  'model': 'all-roberta-large-v1'
125
  })
126
 
 
127
  @app.route('/search', methods=['POST'])
128
  def search_text():
 
 
 
 
 
 
 
 
 
129
  try:
130
  data = request.get_json()
 
131
  if not data or 'query' not in data:
132
  return jsonify({'error': 'Missing query parameter'}), 400
 
133
  query = data['query']
134
  k = data.get('k', 10)
 
 
135
  if not isinstance(k, int) or k < 1 or k > 100:
136
  return jsonify({'error': 'k must be an integer between 1 and 100'}), 400
 
 
137
  results = search_engine.search(query, k=k)
 
138
  return jsonify({
139
  'query': query,
140
  'k': k,
141
  'results': results
142
  })
 
143
  except Exception as e:
144
  return jsonify({'error': str(e)}), 500
145
 
 
146
  @app.route('/search/file', methods=['POST'])
147
  def search_file():
 
 
 
 
 
 
 
148
  try:
 
149
  if 'file' not in request.files:
150
  return jsonify({'error': 'No file provided'}), 400
 
151
  file = request.files['file']
 
152
  if file.filename == '':
153
  return jsonify({'error': 'Empty filename'}), 400
 
 
154
  file_extension = os.path.splitext(file.filename)[1].lower()
 
155
  if file_extension not in ['.txt', '.pdf', '.md']:
156
  return jsonify({'error': 'Unsupported file type. Supported: .txt, .pdf, .md'}), 400
 
 
157
  file_content = file.read()
 
 
158
  k = request.form.get('k', 10, type=int)
 
 
159
  if k < 1 or k > 100:
160
  return jsonify({'error': 'k must be between 1 and 100'}), 400
 
 
161
  results = search_engine.search_by_file(file_content, file_extension, k=k)
 
162
  return jsonify({
163
  'filename': file.filename,
164
  'k': k,
165
  'results': results
166
  })
 
167
  except ValueError as e:
168
  return jsonify({'error': str(e)}), 400
169
  except Exception as e:
170
  return jsonify({'error': str(e)}), 500
171
 
 
172
  if __name__ == '__main__':
173
+ port = 7860 # Chuẩn cho HF Spaces
174
+ app.run(host='0.0.0.0', port=port)