"""Flask service for semantic paper search.

Downloads a precomputed HDF5 embedding store and an HNSW index from a
Hugging Face dataset, embeds incoming queries with a SentenceTransformer,
and serves nearest-neighbour search over paper URLs via:

  GET  /health       — service status and corpus stats
  POST /search       — JSON body {"query": str, "k": int} text search
  POST /search/file  — multipart upload (.txt/.md/.pdf) used as the query

Listens on port 7860 (the Hugging Face Spaces convention).
"""
import io
import os

import h5py
import hnswlib
import numpy as np
import PyPDF2
from flask import Flask, jsonify, request
from flask_cors import CORS
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer

app = Flask(__name__)
CORS(app, origins=['*'])  # open CORS: this service is a public search API

print("\n" + "=" * 50)
print("📥 INITIALIZING PAPER SERVER...")
print("=" * 50)

# Dataset configuration
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_ID = "huynguyen6906/Paper_server_data"  # replace with your own dataset ID if different

# Download data files from the Hugging Face dataset. On any failure, fall
# back to local file names so a pre-provisioned container can still start.
try:
    print(f"Downloading data from {DATASET_ID}...")
    H5_FILE_PATH = hf_hub_download(
        repo_id=DATASET_ID,
        filename="Papers_Embedbed_0-1000000.h5",
        repo_type="dataset",
        token=HF_TOKEN,
    )
    BIN_FILE_PATH = hf_hub_download(
        repo_id=DATASET_ID,
        filename="hnsw_paper_index.bin",
        repo_type="dataset",
        token=HF_TOKEN,
    )
    print(f"✅ Data loaded: {H5_FILE_PATH}")
except Exception as e:
    print(f"❌ Error downloading data: {str(e)}")
    H5_FILE_PATH = 'Papers_Embedbed_0-1000000.h5'
    BIN_FILE_PATH = 'hnsw_paper_index.bin'


class PaperSearchEngine:
    """Embeds text with all-roberta-large-v1 and searches an HNSW cosine index.

    Holds an open read-only handle to the HDF5 store (datasets ``embeddings``
    and ``urls``) for the lifetime of the process.
    """

    def __init__(self, h5_file_path=H5_FILE_PATH):
        """Load the embedding model, the HDF5 store, and the HNSW index.

        Loads a prebuilt index from ``BIN_FILE_PATH`` when present; otherwise
        builds one from the stored embeddings and saves it for next startup.

        Raises:
            FileNotFoundError: if the HDF5 file is missing.
        """
        print("Initializing Paper Search Engine...")

        # Load Sentence Transformer model (1024-d output; must match the
        # dimension the stored embeddings were produced with).
        print("Loading Sentence Transformer model (all-roberta-large-v1)...")
        self.model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
        print("Model loaded successfully!")

        if not os.path.exists(h5_file_path):
            print(f"❌ Error: {h5_file_path} not found!")
            raise FileNotFoundError(f"Required file not found: {h5_file_path}")

        # Keep the HDF5 file open: datasets are read lazily per lookup.
        print(f"Loading embeddings from {h5_file_path}...")
        self.paper = h5py.File(h5_file_path, 'r')
        print(f"Loaded {len(self.paper['urls'])} paper embeddings")
        print(f"Embedding dimension: {self.paper['embeddings'].shape[1]}")

        dim = self.paper["embeddings"].shape[1]
        max_elements = len(self.paper["embeddings"])

        if os.path.exists(BIN_FILE_PATH):
            # Fast path: load the prebuilt index from disk.
            print(f"⚡ Loading HNSW index from {BIN_FILE_PATH}...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.load_index(BIN_FILE_PATH, max_elements=max_elements)
            self.index.set_ef(200)  # query-time accuracy/speed trade-off
            print("✅ HNSW index loaded!")
        else:
            # Slow path: build the index from all stored embeddings, then
            # persist it so subsequent startups take the fast path.
            print("Building HNSW index from scratch...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.init_index(max_elements=max_elements, ef_construction=400, M=200)
            self.index.add_items(self.paper["embeddings"])
            self.index.set_ef(200)
            self.index.save_index(BIN_FILE_PATH)
            print(f"💾 Saved HNSW index to: {BIN_FILE_PATH}")

        print("Paper Search Engine ready!")

    def text_to_vector(self, text):
        """Return the L2-normalized embedding (1-D numpy array) for ``text``."""
        embedding = self.model.encode(
            [text], convert_to_numpy=True, normalize_embeddings=True
        )
        return embedding[0]

    def extract_text_from_file(self, file_content, file_extension):
        """Extract plain text from raw file bytes.

        Args:
            file_content: raw bytes of the uploaded file.
            file_extension: lowercase extension including the dot
                ('.txt', '.md', or '.pdf').

        Returns:
            Decoded/extracted text.

        Raises:
            ValueError: for unsupported extensions or PDF extraction errors.
        """
        if file_extension in ['.txt', '.md']:
            try:
                return file_content.decode('utf-8')
            except UnicodeDecodeError:
                # latin-1 maps every byte, so this fallback cannot fail.
                return file_content.decode('latin-1')
        elif file_extension == '.pdf':
            try:
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
                text = ""
                for page in pdf_reader.pages:
                    # extract_text() may return None for image-only pages;
                    # guard so concatenation does not raise TypeError.
                    text += (page.extract_text() or "") + "\n"
                return text.strip()
            except Exception as e:
                raise ValueError(f"Error extracting text from PDF: {str(e)}")
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

    def search(self, query_text, k=10):
        """Return the ``k`` nearest papers to ``query_text``.

        Returns:
            List of ``{'url': str, 'similarity': float}`` dicts, where
            similarity = 1 - cosine distance.
        """
        query_vector = self.text_to_vector(query_text)
        labels, distances = self.index.knn_query(query_vector, k=k)
        similarities = 1 - distances[0]
        results = []
        for idx, similarity in zip(labels[0], similarities):
            results.append({
                'url': self.paper["urls"][idx].decode('utf-8'),
                'similarity': float(similarity),
            })
        return results

    def search_by_file(self, file_content, file_extension, k=10):
        """Extract text from an uploaded file and search with it as the query."""
        text = self.extract_text_from_file(file_content, file_extension)
        return self.search(text, k)


# Singleton engine shared by all request handlers.
search_engine = PaperSearchEngine(h5_file_path=H5_FILE_PATH)


@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report corpus size, embedding dim, and model name."""
    return jsonify({
        'status': 'healthy',
        'service': 'paper-search-engine',
        'total_papers': len(search_engine.paper["urls"]),
        'embedding_dim': search_engine.paper["embeddings"].shape[1],
        'model': 'all-roberta-large-v1',
    })


@app.route('/search', methods=['POST'])
def search_text():
    """Text search. JSON body: {"query": str, "k": int (1-100, default 10)}."""
    try:
        data = request.get_json()
        if not data or 'query' not in data:
            return jsonify({'error': 'Missing query parameter'}), 400
        query = data['query']
        k = data.get('k', 10)
        if not isinstance(k, int) or k < 1 or k > 100:
            return jsonify({'error': 'k must be an integer between 1 and 100'}), 400
        results = search_engine.search(query, k=k)
        return jsonify({
            'query': query,
            'k': k,
            'results': results,
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/search/file', methods=['POST'])
def search_file():
    """File search. Multipart form: 'file' (.txt/.md/.pdf) and optional 'k'."""
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'Empty filename'}), 400
        file_extension = os.path.splitext(file.filename)[1].lower()
        if file_extension not in ['.txt', '.pdf', '.md']:
            return jsonify({'error': 'Unsupported file type. Supported: .txt, .pdf, .md'}), 400
        file_content = file.read()
        k = request.form.get('k', 10, type=int)
        if k < 1 or k > 100:
            return jsonify({'error': 'k must be between 1 and 100'}), 400
        results = search_engine.search_by_file(file_content, file_extension, k=k)
        return jsonify({
            'filename': file.filename,
            'k': k,
            'results': results,
        })
    except ValueError as e:
        # Raised by extract_text_from_file for bad/unreadable uploads.
        return jsonify({'error': str(e)}), 400
    except Exception as e:
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    port = 7860  # standard port for Hugging Face Spaces
    app.run(host='0.0.0.0', port=port)