huynguyen6906 committed on
Commit
fff5f84
·
verified ·
1 Parent(s): 2e0a39a

Upload server_paper_RAM_optimize.py

Browse files
Files changed (1) hide show
  1. server_paper_RAM_optimize.py +297 -0
server_paper_RAM_optimize.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import h5py
4
+ import hnswlib
5
+ from flask import Flask, request, jsonify
6
+ from flask_cors import CORS
7
+ from sentence_transformers import SentenceTransformer
8
+ import PyPDF2
9
+ import io
10
+
11
# Flask application with permissive CORS (all origins; tighten for production).
app = Flask(__name__)
CORS(app, origins=['*'])

# Locations of the embedding store and the serialized HNSW index.
H5_FILE_PATH = 'Papers_Embedbed_0-1000000.h5'
BIN_FILE_PATH = 'hnsw_paper_index.bin'
# Expose the HDF5 path through the environment for downstream consumers.
os.environ['H5_FILE_PATH'] = H5_FILE_PATH
16
+
17
class PaperSearchEngine:
    """Semantic paper search over precomputed embeddings via an HNSW index.

    Embeddings and paper URLs are read from an HDF5 file; nearest-neighbour
    search uses hnswlib with cosine distance. The HDF5 file is kept open
    read-only so datasets are sliced lazily from disk instead of being
    loaded wholesale into RAM.
    """

    def __init__(self, h5_file_path=H5_FILE_PATH):
        """
        Initialize the Paper Search Engine with Sentence Transformers and HNSW index.

        Args:
            h5_file_path: Path to the HDF5 file containing paper embeddings and URLs

        Raises:
            FileNotFoundError: If the HDF5 file does not exist.
        """
        print("Initializing Paper Search Engine...")

        # Load the same Sentence Transformer model that produced the stored
        # embeddings, so query vectors live in the same space.
        print("Loading Sentence Transformer model (all-roberta-large-v1)...")
        self.model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
        print("Model loaded successfully!")

        # The .h5 file is mandatory: it holds both embeddings and URLs.
        if not os.path.exists(h5_file_path):
            print(f"❌ Error: {h5_file_path} not found!")
            print("   Please ensure the h5 file is in the backend directory")
            raise FileNotFoundError(f"Required file not found: {h5_file_path}")

        # A prebuilt .bin index allows a much faster startup.
        bin_exists = os.path.exists(BIN_FILE_PATH)
        if bin_exists:
            print(f"⚡ Found existing HNSW index: {BIN_FILE_PATH}")
        else:
            print(".bin file not found, will build index from embeddings")

        # Keep the file handle open for the server's lifetime; h5py reads
        # slices on demand, which keeps resident memory low.
        print(f"Loading embeddings from {h5_file_path}...")
        self.paper = h5py.File(h5_file_path, 'r')

        print(f"Loaded {len(self.paper['urls'])} paper embeddings")
        print(f"Embedding dimension: {self.paper['embeddings'].shape[1]}")

        dim = self.paper["embeddings"].shape[1]
        max_elements = len(self.paper["embeddings"])

        if bin_exists:
            # Fast path: deserialize the previously saved index.
            print("Loading HNSW index from .bin file (fast mode)...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.load_index(BIN_FILE_PATH, max_elements=max_elements)
            self.index.set_ef(200)
            print("✅ HNSW index loaded successfully from .bin file!")
        else:
            # Slow path: build the index from scratch (first run only).
            print("Building HNSW index from scratch...")
            print("(This may take a while for the first run)")
            self.index = hnswlib.Index(space='cosine', dim=dim)

            self.index.init_index(
                max_elements=max_elements,
                ef_construction=400,  # Higher = better quality, slower build
                M=200                 # Number of connections per element
            )

            # Insert embeddings in chunks: passing the whole h5py dataset to
            # add_items would materialize all max_elements x dim floats in RAM
            # at once. Slicing keeps only one chunk resident at a time.
            chunk_size = 50_000
            embeddings = self.paper["embeddings"]
            for start in range(0, max_elements, chunk_size):
                stop = min(start + chunk_size, max_elements)
                self.index.add_items(embeddings[start:stop], np.arange(start, stop))

            # ef trades search accuracy against latency (higher = more accurate).
            self.index.set_ef(200)

            # Persist the index so the next startup takes the fast path.
            self.index.save_index(BIN_FILE_PATH)
            print(f"💾 Saved HNSW index to: {BIN_FILE_PATH}")
            print("   (Next startup will be faster!)")

        print("HNSW index built successfully!")
        print("Paper Search Engine ready!")

    def text_to_vector(self, text):
        """
        Convert text to embedding vector using Sentence Transformer.

        Args:
            text: Input text (query, abstract, etc.)

        Returns:
            numpy array: L2-normalized embedding vector
        """
        # Encode as a single-item batch and normalize so cosine distance
        # in the index matches the stored (normalized) embeddings.
        embedding = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)
        return embedding[0]

    def extract_text_from_file(self, file_content, file_extension):
        """
        Extract text from uploaded file.

        Args:
            file_content: File content as bytes
            file_extension: File extension (.txt, .pdf, .md)

        Returns:
            str: Extracted text

        Raises:
            ValueError: For unreadable PDFs or unsupported extensions.
        """
        if file_extension in ('.txt', '.md'):
            # Plain text: try UTF-8 first, fall back to latin-1 (never fails).
            try:
                return file_content.decode('utf-8')
            except UnicodeDecodeError:
                return file_content.decode('latin-1')

        elif file_extension == '.pdf':
            # Concatenate the text of every page.
            try:
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
                text = ""
                for page in pdf_reader.pages:
                    text += page.extract_text() + "\n"
                return text.strip()
            except Exception as e:
                raise ValueError(f"Error extracting text from PDF: {str(e)}")

        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

    def search(self, query_text, k=10):
        """
        Search for similar papers using text query.

        Args:
            query_text: Text query (keywords, sentence, abstract, etc.)
            k: Number of results to return

        Returns:
            list: List of {'url': str, 'similarity': float} dicts,
            nearest first.
        """
        # Embed the query into the index's vector space.
        query_vector = self.text_to_vector(query_text)

        # Approximate nearest-neighbour lookup.
        labels, distances = self.index.knn_query(query_vector, k=k)

        # Cosine distance -> similarity.
        similarities = 1 - distances[0]

        results = []
        for idx, similarity in zip(labels[0], similarities):
            results.append({
                'url': self.paper["urls"][idx].decode('utf-8'),
                'similarity': float(similarity)
            })

        return results

    def search_by_file(self, file_content, file_extension, k=10):
        """
        Search for similar papers using uploaded file content.

        Args:
            file_content: File content as bytes
            file_extension: File extension
            k: Number of results to return

        Returns:
            list: List of {'url': str, 'similarity': float} dicts.
        """
        # Extract text from file, then reuse the plain text search path.
        text = self.extract_text_from_file(file_content, file_extension)
        return self.search(text, k)
182
+
183
+
184
# Build the singleton search engine at import time so all request handlers
# share a single model and index instance.
print("Starting Flask server...")
H5_FILE_PATH = os.environ.get('H5_FILE_PATH', 'Papers_Embedbed_0-1000000.h5')
search_engine = PaperSearchEngine(h5_file_path=H5_FILE_PATH)
188
+
189
+
190
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: reports service identity and basic corpus stats."""
    payload = {
        'status': 'healthy',
        'service': 'paper-search-engine',
        'total_papers': len(search_engine.paper["urls"]),
        'embedding_dim': search_engine.paper["embeddings"].shape[1],
        'model': 'all-roberta-large-v1',
    }
    return jsonify(payload)
200
+
201
+
202
@app.route('/search', methods=['POST'])
def search_text():
    """
    Text-based paper search endpoint.

    Expects JSON:
    {
        "query": "machine learning transformers",
        "k": 10
    }

    Returns:
        200 with {'query', 'k', 'results'} on success;
        400 on missing/invalid parameters; 500 on internal errors.
    """
    try:
        data = request.get_json()

        if not data or 'query' not in data:
            return jsonify({'error': 'Missing query parameter'}), 400

        query = data['query']
        k = data.get('k', 10)

        # Validate k. bool is a subclass of int in Python, so JSON true/false
        # would pass a plain isinstance(k, int) check — exclude it explicitly.
        if isinstance(k, bool) or not isinstance(k, int) or k < 1 or k > 100:
            return jsonify({'error': 'k must be an integer between 1 and 100'}), 400

        # Perform the similarity search.
        results = search_engine.search(query, k=k)

        return jsonify({
            'query': query,
            'k': k,
            'results': results
        })

    except Exception as e:
        # Broad catch keeps the API responsive; the message is returned to the client.
        return jsonify({'error': str(e)}), 500
237
+
238
+
239
@app.route('/search/file', methods=['POST'])
def search_file():
    """
    File-based paper search endpoint.

    Expects multipart/form-data with:
        - file: The uploaded file (.txt, .pdf, .md)
        - k: Number of results (optional, default 10)
    """
    try:
        # A file part is mandatory.
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        upload = request.files['file']
        if upload.filename == '':
            return jsonify({'error': 'Empty filename'}), 400

        # Only extensions the engine can parse are accepted.
        extension = os.path.splitext(upload.filename)[1].lower()
        if extension not in ('.txt', '.pdf', '.md'):
            return jsonify({'error': 'Unsupported file type. Supported: .txt, .pdf, .md'}), 400

        raw_bytes = upload.read()

        # Result count (werkzeug coerces to int, falls back to 10 on failure).
        k = request.form.get('k', 10, type=int)
        if k < 1 or k > 100:
            return jsonify({'error': 'k must be between 1 and 100'}), 400

        matches = search_engine.search_by_file(raw_bytes, extension, k=k)

        return jsonify({
            'filename': upload.filename,
            'k': k,
            'results': matches
        })

    except ValueError as e:
        # Raised by text extraction (corrupt PDF / unsupported type).
        return jsonify({'error': str(e)}), 400
    except Exception as e:
        return jsonify({'error': str(e)}), 500
287
+
288
+
289
if __name__ == '__main__':
    # Default to 5001 to avoid conflict with the image server.
    port = int(os.environ.get('PORT', 5001))
    # Debug mode only when FLASK_DEBUG=1 is set explicitly.
    debug = os.environ.get('FLASK_DEBUG', '0') == '1'

    print(f"\nPaper Search Engine running on http://localhost:{port}")
    print(f"Health check: http://localhost:{port}/health")
    print(f"Total papers indexed: {len(search_engine.paper['urls'])}")

    app.run(host='0.0.0.0', port=port, debug=debug)