# Spaces: Running — Hugging Face Spaces status banner (scrape residue, not code)
import io
import os

import h5py
import hnswlib
import numpy as np
import PyPDF2
from flask import Flask, request, jsonify
from flask_cors import CORS
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer

# Flask app with permissive CORS so any frontend origin can query the API.
app = Flask(__name__)
CORS(app, origins=['*'])

# Startup banner.
_BANNER = "=" * 50
print(f"\n{_BANNER}")
print("📥 INITIALIZING PAPER SERVER...")
print(_BANNER)
# Dataset configuration.
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_ID = "huynguyen6906/Paper_server_data"  # replace with your own paper dataset ID if different

# Fetch the data artifacts from the Hugging Face dataset hub; on any
# download failure, fall back to local copies in the working directory.
try:
    print(f"Downloading data from {DATASET_ID}...")

    def _fetch(name):
        # Download one file from the dataset repo and return its local path.
        return hf_hub_download(repo_id=DATASET_ID, filename=name, repo_type="dataset", token=HF_TOKEN)

    H5_FILE_PATH = _fetch("Papers_Embedbed_0-1000000.h5")
    BIN_FILE_PATH = _fetch("hnsw_paper_index.bin")
    print(f"✅ Data loaded: {H5_FILE_PATH}")
except Exception as e:
    print(f"❌ Error downloading data: {str(e)}")
    H5_FILE_PATH = 'Papers_Embedbed_0-1000000.h5'
    BIN_FILE_PATH = 'hnsw_paper_index.bin'
class PaperSearchEngine:
    """Semantic search over paper embeddings stored in HDF5, indexed with HNSW.

    Loads a Sentence Transformer encoder, opens the embedding matrix and URL
    list from an HDF5 file, and answers approximate nearest-neighbour queries
    (cosine similarity) through an hnswlib index that is loaded from disk when
    a cached copy exists, or built and cached otherwise.
    """

    def __init__(self, h5_file_path=None):
        """Load the encoder, the HDF5 data and the HNSW index.

        Args:
            h5_file_path: Path to the embeddings HDF5 file. Defaults to the
                module-level H5_FILE_PATH (resolved lazily at call time so
                defining this class never depends on module init order).

        Raises:
            FileNotFoundError: If the HDF5 file does not exist.
        """
        if h5_file_path is None:
            h5_file_path = H5_FILE_PATH
        print("Initializing Paper Search Engine...")
        # Load the sentence encoder used for both queries and (originally) papers.
        print("Loading Sentence Transformer model (all-roberta-large-v1)...")
        self.model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
        print("Model loaded successfully!")
        if not os.path.exists(h5_file_path):
            print(f"❌ Error: {h5_file_path} not found!")
            raise FileNotFoundError(f"Required file not found: {h5_file_path}")
        # Keep the HDF5 file open for the process lifetime; datasets are read
        # lazily, so the embedding matrix is not pulled into RAM here.
        print(f"Loading embeddings from {h5_file_path}...")
        self.paper = h5py.File(h5_file_path, 'r')
        print(f"Loaded {len(self.paper['urls'])} paper embeddings")
        print(f"Embedding dimension: {self.paper['embeddings'].shape[1]}")
        dim = self.paper["embeddings"].shape[1]
        max_elements = len(self.paper["embeddings"])
        if os.path.exists(BIN_FILE_PATH):
            # Fast path: reload the prebuilt index from disk.
            print(f"⚡ Loading HNSW index from {BIN_FILE_PATH}...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.load_index(BIN_FILE_PATH, max_elements=max_elements)
            self.index.set_ef(200)
            print("✅ HNSW index loaded!")
        else:
            # Slow path: build the index once, then cache it for future runs.
            print("Building HNSW index from scratch...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.init_index(max_elements=max_elements, ef_construction=400, M=200)
            self.index.add_items(self.paper["embeddings"])
            self.index.set_ef(200)
            self.index.save_index(BIN_FILE_PATH)
            print(f"💾 Saved HNSW index to: {BIN_FILE_PATH}")
        print("Paper Search Engine ready!")

    def text_to_vector(self, text):
        """Encode *text* into a single L2-normalised embedding vector."""
        embedding = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)
        return embedding[0]

    def extract_text_from_file(self, file_content, file_extension):
        """Extract plain text from raw uploaded-file bytes.

        Args:
            file_content: Raw bytes of the uploaded file.
            file_extension: Lower-cased extension including the dot.

        Returns:
            The extracted text.

        Raises:
            ValueError: For unsupported extensions or unreadable PDFs.
        """
        if file_extension in ['.txt', '.md']:
            try:
                return file_content.decode('utf-8')
            except UnicodeDecodeError:
                # latin-1 maps every byte value, so this fallback never raises.
                return file_content.decode('latin-1')
        elif file_extension == '.pdf':
            try:
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
                text = ""
                for page in pdf_reader.pages:
                    # extract_text() may return None (e.g. image-only pages);
                    # treat that as empty instead of crashing on None + str.
                    text += (page.extract_text() or "") + "\n"
                return text.strip()
            except Exception as e:
                raise ValueError(f"Error extracting text from PDF: {str(e)}") from e
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

    def search(self, query_text, k=10):
        """Return the *k* most similar papers to *query_text*.

        Returns:
            A list of {'url': str, 'similarity': float} dicts, where
            similarity = 1 - cosine distance (higher is more similar).
        """
        query_vector = self.text_to_vector(query_text)
        # Never ask the index for more neighbours than it actually holds.
        k = min(k, self.index.get_current_count())
        labels, distances = self.index.knn_query(query_vector, k=k)
        similarities = 1 - distances[0]
        results = []
        for idx, similarity in zip(labels[0], similarities):
            raw_url = self.paper["urls"][idx]
            results.append({
                # h5py string datasets usually yield bytes; tolerate str too.
                'url': raw_url.decode('utf-8') if isinstance(raw_url, bytes) else str(raw_url),
                'similarity': float(similarity)
            })
        return results

    def search_by_file(self, file_content, file_extension, k=10):
        """Extract text from an uploaded file, then run a semantic search on it."""
        text = self.extract_text_from_file(file_content, file_extension)
        return self.search(text, k)
# Module-level singleton: built once at import time so every request
# handler shares the loaded model, embeddings and HNSW index.
search_engine = PaperSearchEngine(H5_FILE_PATH)
# NOTE(review): no @app.route decorator was present in SOURCE, so this
# endpoint was never registered with Flask. Path assumed — confirm intended URL.
@app.route('/health', methods=['GET'])
def health_check():
    """Health endpoint reporting corpus size, embedding dim and model name."""
    return jsonify({
        'status': 'healthy',
        'service': 'paper-search-engine',
        'total_papers': len(search_engine.paper["urls"]),
        'embedding_dim': search_engine.paper["embeddings"].shape[1],
        'model': 'all-roberta-large-v1'
    })
# NOTE(review): no @app.route decorator was present in SOURCE, so this
# endpoint was never registered with Flask. Path assumed — confirm intended URL.
@app.route('/search', methods=['POST'])
def search_text():
    """POST JSON {'query': str, 'k': int (optional, default 10)}.

    Returns the top-k semantically similar papers; 400 on missing query or
    invalid k, 500 on internal errors.
    """
    try:
        data = request.get_json()
        if not data or 'query' not in data:
            return jsonify({'error': 'Missing query parameter'}), 400
        query = data['query']
        k = data.get('k', 10)
        # bool is a subclass of int, so reject True/False explicitly.
        if isinstance(k, bool) or not isinstance(k, int) or k < 1 or k > 100:
            return jsonify({'error': 'k must be an integer between 1 and 100'}), 400
        results = search_engine.search(query, k=k)
        return jsonify({
            'query': query,
            'k': k,
            'results': results
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
# NOTE(review): no @app.route decorator was present in SOURCE, so this
# endpoint was never registered with Flask. Path assumed — confirm intended URL.
@app.route('/search_file', methods=['POST'])
def search_file():
    """POST multipart 'file' (+ optional form field 'k', default 10).

    Extracts text from a .txt/.md/.pdf upload and returns the top-k similar
    papers; 400 for bad input, 500 on internal errors.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'Empty filename'}), 400
        file_extension = os.path.splitext(file.filename)[1].lower()
        if file_extension not in ['.txt', '.pdf', '.md']:
            return jsonify({'error': 'Unsupported file type. Supported: .txt, .pdf, .md'}), 400
        file_content = file.read()
        # type=int makes Flask fall back to the default when 'k' is not numeric.
        k = request.form.get('k', 10, type=int)
        if k < 1 or k > 100:
            return jsonify({'error': 'k must be between 1 and 100'}), 400
        results = search_engine.search_by_file(file_content, file_extension, k=k)
        return jsonify({
            'filename': file.filename,
            'k': k,
            'results': results
        })
    except ValueError as e:
        # extract_text_from_file signals bad uploads with ValueError.
        return jsonify({'error': str(e)}), 400
    except Exception as e:
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # 7860 is the standard port expected by Hugging Face Spaces.
    app.run(host='0.0.0.0', port=7860)