# PAPER_SERVER / server_paper_RAM_optimize.py
# huynguyen6906's picture
# Update server_paper_RAM_optimize.py
# ff255f8 verified
# (The four lines above are a Hugging Face Hub page header that was pasted
# into the file; commented out so the module parses as Python.)
import os
import numpy as np
import h5py
import hnswlib
from flask import Flask, request, jsonify
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
import PyPDF2
import io
from huggingface_hub import hf_hub_download
# --- Flask application setup ---
app = Flask(__name__)
# NOTE(review): origins=['*'] allows any website to call this API — confirm
# this is intentional for a public Hugging Face Space deployment.
CORS(app, origins=['*'])
print("\n" + "="*50)
print("📥 INITIALIZING PAPER SERVER...")
print("="*50)
# Dataset configuration
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_ID = "huynguyen6906/Paper_server_data" # Replace with your own dataset ID for papers if different
# Download the embedding file and prebuilt HNSW index from the Hugging Face
# dataset. On any failure (missing token, network error, ...) fall back to
# local filenames in the working directory; PaperSearchEngine below raises
# FileNotFoundError if the .h5 file is still absent.
try:
    print(f"Downloading data from {DATASET_ID}...")
    H5_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="Papers_Embedbed_0-1000000.h5", repo_type="dataset", token=HF_TOKEN)
    BIN_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="hnsw_paper_index.bin", repo_type="dataset", token=HF_TOKEN)
    print(f"✅ Data loaded: {H5_FILE_PATH}")
except Exception as e:
    print(f"❌ Error downloading data: {str(e)}")
    H5_FILE_PATH = 'Papers_Embedbed_0-1000000.h5'
    BIN_FILE_PATH = 'hnsw_paper_index.bin'
class PaperSearchEngine:
    """Semantic search over paper embeddings stored in an HDF5 file.

    Uses a SentenceTransformer model to embed queries and an hnswlib HNSW
    index (cosine space) for approximate nearest-neighbour lookup. The HDF5
    file must contain an ``embeddings`` dataset (N x dim) and a parallel
    ``urls`` dataset of byte strings.
    """

    def __init__(self, h5_file_path=None):
        """Load the model, the HDF5 data, and the HNSW index.

        Args:
            h5_file_path: Path to the embeddings .h5 file. Defaults to the
                module-level H5_FILE_PATH.

        Raises:
            FileNotFoundError: If the .h5 file does not exist.
        """
        # Resolve the default lazily: the original default-argument form bound
        # H5_FILE_PATH at class-definition time, coupling class creation to
        # module import order. Behavior for callers is unchanged.
        if h5_file_path is None:
            h5_file_path = H5_FILE_PATH
        print("Initializing Paper Search Engine...")
        # Load Sentence Transformer model
        print("Loading Sentence Transformer model (all-roberta-large-v1)...")
        self.model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
        print("Model loaded successfully!")
        # Check if .h5 file exists
        if not os.path.exists(h5_file_path):
            print(f"❌ Error: {h5_file_path} not found!")
            raise FileNotFoundError(f"Required file not found: {h5_file_path}")
        # Load embeddings and URLs from HDF5. The file handle is kept open
        # for the process lifetime (datasets are read lazily to save RAM).
        print(f"Loading embeddings from {h5_file_path}...")
        self.paper = h5py.File(h5_file_path, 'r')
        print(f"Loaded {len(self.paper['urls'])} paper embeddings")
        print(f"Embedding dimension: {self.paper['embeddings'].shape[1]}")
        dim = self.paper["embeddings"].shape[1]
        max_elements = len(self.paper["embeddings"])
        # Prefer loading a prebuilt index file — much faster than rebuilding.
        if os.path.exists(BIN_FILE_PATH):
            print(f"⚡ Loading HNSW index from {BIN_FILE_PATH}...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.load_index(BIN_FILE_PATH, max_elements=max_elements)
            self.index.set_ef(200)  # search-time accuracy/speed trade-off
            print("✅ HNSW index loaded!")
        else:
            # Build HNSW index from scratch and cache it to disk.
            print("Building HNSW index from scratch...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.init_index(max_elements=max_elements, ef_construction=400, M=200)
            self.index.add_items(self.paper["embeddings"])
            self.index.set_ef(200)
            self.index.save_index(BIN_FILE_PATH)
            print(f"💾 Saved HNSW index to: {BIN_FILE_PATH}")
        print("Paper Search Engine ready!")

    def text_to_vector(self, text):
        """Return the L2-normalized embedding vector for *text* (1-D array)."""
        embedding = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)
        return embedding[0]

    def extract_text_from_file(self, file_content, file_extension):
        """Extract plain text from raw file bytes.

        Args:
            file_content: The file's raw bytes.
            file_extension: Lower-cased extension including the dot
                ('.txt', '.md', or '.pdf').

        Returns:
            The decoded/extracted text.

        Raises:
            ValueError: For unsupported extensions or unreadable PDFs.
        """
        if file_extension in ['.txt', '.md']:
            try:
                return file_content.decode('utf-8')
            except UnicodeDecodeError:
                # latin-1 maps every byte, so this fallback never raises.
                return file_content.decode('latin-1')
        elif file_extension == '.pdf':
            try:
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
                # extract_text() may return None for image-only pages; treat
                # those as empty rather than crashing on None + str.
                pages = [(page.extract_text() or "") for page in pdf_reader.pages]
                return "\n".join(pages).strip()
            except Exception as e:
                raise ValueError(f"Error extracting text from PDF: {str(e)}")
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

    def search(self, query_text, k=10):
        """Return the *k* nearest papers to *query_text*.

        Returns a list of dicts with 'url' and 'similarity' (cosine
        similarity = 1 - cosine distance).
        """
        query_vector = self.text_to_vector(query_text)
        labels, distances = self.index.knn_query(query_vector, k=k)
        similarities = 1 - distances[0]
        results = []
        for idx, similarity in zip(labels[0], similarities):
            results.append({
                'url': self.paper["urls"][idx].decode('utf-8'),
                'similarity': float(similarity)
            })
        return results

    def search_by_file(self, file_content, file_extension, k=10):
        """Extract text from an uploaded file, then run a normal search."""
        text = self.extract_text_from_file(file_content, file_extension)
        return self.search(text, k)
# Build the engine once at import time so every request shares the loaded
# model, HDF5 datasets, and HNSW index. This blocks server startup until the
# model and data are fully loaded.
search_engine = PaperSearchEngine(h5_file_path=H5_FILE_PATH)
@app.route('/health', methods=['GET'])
def health_check():
    """GET /health — liveness probe with basic corpus statistics."""
    paper_store = search_engine.paper
    payload = {
        'status': 'healthy',
        'service': 'paper-search-engine',
        'total_papers': len(paper_store["urls"]),
        'embedding_dim': paper_store["embeddings"].shape[1],
        'model': 'all-roberta-large-v1',
    }
    return jsonify(payload)
@app.route('/search', methods=['POST'])
def search_text():
    """POST /search — semantic text search.

    JSON body: {"query": str, "k": int (optional, default 10, 1..100)}.
    Returns {"query", "k", "results"} or {"error"} with 400/500.
    """
    try:
        data = request.get_json()
        if not data or 'query' not in data:
            return jsonify({'error': 'Missing query parameter'}), 400
        query = data['query']
        # Reject non-string/empty queries up front so a malformed body yields
        # a 400 instead of a 500 raised deep inside the embedding model.
        if not isinstance(query, str) or not query.strip():
            return jsonify({'error': 'query must be a non-empty string'}), 400
        k = data.get('k', 10)
        # bool is a subclass of int, so JSON true/false must be excluded
        # explicitly or it would slip through as k=1/k=0.
        if isinstance(k, bool) or not isinstance(k, int) or k < 1 or k > 100:
            return jsonify({'error': 'k must be an integer between 1 and 100'}), 400
        results = search_engine.search(query, k=k)
        return jsonify({
            'query': query,
            'k': k,
            'results': results
        })
    except Exception as e:
        # Last-resort guard. NOTE(review): str(e) may leak internal details
        # to clients — consider logging the traceback and returning a generic
        # message instead.
        return jsonify({'error': str(e)}), 500
@app.route('/search/file', methods=['POST'])
def search_file():
    """POST /search/file — search by uploaded document.

    Multipart form: 'file' (.txt, .md, or .pdf) plus optional 'k' (1..100,
    default 10). Returns {"filename", "k", "results"} or {"error"}.
    """
    try:
        # Guard clauses: validate the upload before reading any bytes.
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400
        uploaded = request.files['file']
        if uploaded.filename == '':
            return jsonify({'error': 'Empty filename'}), 400
        extension = os.path.splitext(uploaded.filename)[1].lower()
        if extension not in ['.txt', '.pdf', '.md']:
            return jsonify({'error': 'Unsupported file type. Supported: .txt, .pdf, .md'}), 400
        raw_bytes = uploaded.read()
        k = request.form.get('k', 10, type=int)
        if not (1 <= k <= 100):
            return jsonify({'error': 'k must be between 1 and 100'}), 400
        matches = search_engine.search_by_file(raw_bytes, extension, k=k)
        response_body = {
            'filename': uploaded.filename,
            'k': k,
            'results': matches,
        }
        return jsonify(response_body)
    except ValueError as e:
        # Raised by text extraction (bad PDF / unsupported type): client error.
        return jsonify({'error': str(e)}), 400
    except Exception as e:
        return jsonify({'error': str(e)}), 500
# Entry point when run directly (not under a WSGI server such as gunicorn).
if __name__ == '__main__':
    port = 7860 # Standard port for Hugging Face Spaces
    # 0.0.0.0 so the container's mapped port is reachable from outside.
    app.run(host='0.0.0.0', port=port)