Spaces:
Running
Running
File size: 6,816 Bytes
fff5f84 d43ecac fff5f84 d43ecac ff255f8 d43ecac fff5f84 d43ecac fff5f84 d43ecac fff5f84 d43ecac fff5f84 d43ecac fff5f84 d43ecac fff5f84 d43ecac fff5f84 d43ecac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import os
import numpy as np
import h5py
import hnswlib
from flask import Flask, request, jsonify
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
import PyPDF2
import io
from huggingface_hub import hf_hub_download
app = Flask(__name__)
CORS(app, origins=['*'])

banner = "=" * 50
print("\n" + banner)
print("📥 INITIALIZING PAPER SERVER...")
print(banner)

# Dataset configuration
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_ID = "huynguyen6906/Paper_server_data"  # Replace with your own paper dataset ID if different

# Fetch the embedding file and the prebuilt HNSW index from the HF dataset.
try:
    print(f"Downloading data from {DATASET_ID}...")
    H5_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="Papers_Embedbed_0-1000000.h5", repo_type="dataset", token=HF_TOKEN)
    BIN_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="hnsw_paper_index.bin", repo_type="dataset", token=HF_TOKEN)
    print(f"✅ Data loaded: {H5_FILE_PATH}")
except Exception as e:
    # Best-effort fallback: assume the files are already present locally.
    print(f"❌ Error downloading data: {str(e)}")
    H5_FILE_PATH = 'Papers_Embedbed_0-1000000.h5'
    BIN_FILE_PATH = 'hnsw_paper_index.bin'
class PaperSearchEngine:
    """Semantic paper search: sentence-transformer query embeddings + an
    hnswlib HNSW index over precomputed paper embeddings.

    Paper embeddings and URLs are read from an HDF5 file (datasets
    "embeddings" and "urls"); the HNSW index is loaded from BIN_FILE_PATH
    when present, otherwise built from the embeddings and cached there.
    """

    def __init__(self, h5_file_path=H5_FILE_PATH):
        print("Initializing Paper Search Engine...")

        # Load the sentence-transformer used to embed incoming queries.
        print("Loading Sentence Transformer model (all-roberta-large-v1)...")
        self.model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
        print("Model loaded successfully!")

        # The HDF5 data file is mandatory — fail fast if it is missing.
        if not os.path.exists(h5_file_path):
            print(f"❌ Error: {h5_file_path} not found!")
            raise FileNotFoundError(f"Required file not found: {h5_file_path}")

        # Keep the HDF5 file open for the server's lifetime; h5py reads the
        # "urls" / "embeddings" datasets lazily on access.
        print(f"Loading embeddings from {h5_file_path}...")
        self.paper = h5py.File(h5_file_path, 'r')
        print(f"Loaded {len(self.paper['urls'])} paper embeddings")
        print(f"Embedding dimension: {self.paper['embeddings'].shape[1]}")
        dim = self.paper["embeddings"].shape[1]
        max_elements = len(self.paper["embeddings"])

        if os.path.exists(BIN_FILE_PATH):
            # Fast path: reuse a previously saved index.
            print(f"⚡ Loading HNSW index from {BIN_FILE_PATH}...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.load_index(BIN_FILE_PATH, max_elements=max_elements)
            self.index.set_ef(200)
            print("✅ HNSW index loaded!")
        else:
            # Slow path: build the index from the raw embeddings, then cache
            # it on disk so later startups take the fast path.
            print("Building HNSW index from scratch...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.init_index(max_elements=max_elements, ef_construction=400, M=200)
            self.index.add_items(self.paper["embeddings"])
            self.index.set_ef(200)
            self.index.save_index(BIN_FILE_PATH)
            print(f"💾 Saved HNSW index to: {BIN_FILE_PATH}")
        print("Paper Search Engine ready!")

    def text_to_vector(self, text):
        """Return the normalized embedding vector (1-D numpy array) for *text*."""
        embedding = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)
        return embedding[0]

    def extract_text_from_file(self, file_content, file_extension):
        """Decode raw upload bytes to text.

        Supports .txt/.md (UTF-8 with a latin-1 fallback) and .pdf (PyPDF2).

        Raises:
            ValueError: for unsupported extensions or unreadable PDFs.
        """
        if file_extension in ['.txt', '.md']:
            try:
                return file_content.decode('utf-8')
            except UnicodeDecodeError:
                # latin-1 maps every byte, so this last resort cannot fail.
                return file_content.decode('latin-1')
        elif file_extension == '.pdf':
            try:
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
                # join() instead of repeated "+=" avoids quadratic string
                # growth; "or ''" guards pages with no extractable text.
                text = "\n".join((page.extract_text() or "") for page in pdf_reader.pages)
                return text.strip()
            except Exception as e:
                raise ValueError(f"Error extracting text from PDF: {str(e)}")
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

    def search(self, query_text, k=10):
        """Return the *k* nearest papers to *query_text*.

        Returns:
            list[dict]: [{'url': str, 'similarity': float}, ...], best first.
        """
        query_vector = self.text_to_vector(query_text)
        labels, distances = self.index.knn_query(query_vector, k=k)
        # hnswlib returns cosine *distance*; convert to similarity.
        similarities = 1 - distances[0]
        return [
            {
                'url': self.paper["urls"][idx].decode('utf-8'),
                'similarity': float(similarity),
            }
            for idx, similarity in zip(labels[0], similarities)
        ]

    def search_by_file(self, file_content, file_extension, k=10):
        """Extract text from an uploaded file and run a k-NN search on it."""
        text = self.extract_text_from_file(file_content, file_extension)
        return self.search(text, k)
# Instantiate the engine once at import time; all Flask handlers share it.
search_engine = PaperSearchEngine(h5_file_path=H5_FILE_PATH)
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report service status and corpus statistics."""
    payload = {
        'status': 'healthy',
        'service': 'paper-search-engine',
        'total_papers': len(search_engine.paper["urls"]),
        'embedding_dim': search_engine.paper["embeddings"].shape[1],
        'model': 'all-roberta-large-v1',
    }
    return jsonify(payload)
@app.route('/search', methods=['POST'])
def search_text():
    """POST /search  body: {"query": str, "k": int = 10}.

    Returns the top-k most similar papers as JSON, or a JSON error with an
    appropriate 4xx/5xx status.
    """
    try:
        data = request.get_json()
        if not data or 'query' not in data:
            return jsonify({'error': 'Missing query parameter'}), 400
        query = data['query']
        # Validate before handing the query to the embedding model.
        if not isinstance(query, str) or not query.strip():
            return jsonify({'error': 'query must be a non-empty string'}), 400
        k = data.get('k', 10)
        # bool is a subclass of int, so exclude it explicitly (True would
        # otherwise slip through as k == 1).
        if isinstance(k, bool) or not isinstance(k, int) or k < 1 or k > 100:
            return jsonify({'error': 'k must be an integer between 1 and 100'}), 400
        results = search_engine.search(query, k=k)
        return jsonify({
            'query': query,
            'k': k,
            'results': results
        })
    except Exception as e:
        # Catch-all so the API always answers with JSON, never an HTML traceback.
        return jsonify({'error': str(e)}), 500
@app.route('/search/file', methods=['POST'])
def search_file():
    """POST /search/file (multipart form): search with an uploaded document.

    Accepts .txt, .pdf, or .md plus an optional form field "k" (default 10).
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400
        upload = request.files['file']
        if upload.filename == '':
            return jsonify({'error': 'Empty filename'}), 400
        extension = os.path.splitext(upload.filename)[1].lower()
        if extension not in ('.txt', '.pdf', '.md'):
            return jsonify({'error': 'Unsupported file type. Supported: .txt, .pdf, .md'}), 400
        content = upload.read()
        k = request.form.get('k', 10, type=int)
        if not 1 <= k <= 100:
            return jsonify({'error': 'k must be between 1 and 100'}), 400
        results = search_engine.search_by_file(content, extension, k=k)
        response = {
            'filename': upload.filename,
            'k': k,
            'results': results,
        }
        return jsonify(response)
    except ValueError as e:
        # Raised by the engine for unreadable PDFs / unsupported types.
        return jsonify({'error': str(e)}), 400
    except Exception as e:
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # 7860 is the standard port for Hugging Face Spaces.
    port = 7860
    app.run(host='0.0.0.0', port=port)