import os
import numpy as np
import h5py
import hnswlib
from flask import Flask, request, jsonify
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
import PyPDF2
import io
from huggingface_hub import hf_hub_download

app = Flask(__name__)
# Allow cross-origin requests from any domain — this is a public API.
CORS(app, origins=['*'])

# Startup banner so the initialization phase is visible in the logs.
print("\n" + "="*50)
print("📥 INITIALIZING PAPER SERVER...")
print("="*50)

# Dataset configuration.
# HF_TOKEN is read from the environment; None means anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_ID = "huynguyen6906/Paper_server_data"  # Replace with your own dataset ID for papers if different

# Download the embedding file and the prebuilt HNSW index from the
# Hugging Face dataset hub; hf_hub_download returns a local cache path.
try:
    print(f"Downloading data from {DATASET_ID}...")
    H5_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="Papers_Embedbed_0-1000000.h5", repo_type="dataset", token=HF_TOKEN)
    BIN_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="hnsw_paper_index.bin", repo_type="dataset", token=HF_TOKEN)
    print(f"✅ Data loaded: {H5_FILE_PATH}")
except Exception as e:
    # Best-effort fallback: assume the files already exist in the
    # working directory (e.g. a local dev checkout).
    print(f"❌ Error downloading data: {str(e)}")
    H5_FILE_PATH = 'Papers_Embedbed_0-1000000.h5'
    BIN_FILE_PATH = 'hnsw_paper_index.bin'

class PaperSearchEngine:
    """Semantic search over a corpus of precomputed paper embeddings.

    Queries are embedded with a SentenceTransformer model and matched
    against the corpus via an approximate-nearest-neighbour HNSW index
    (cosine space). Embeddings and URLs come from an HDF5 file with
    'embeddings' and 'urls' datasets.
    """

    def __init__(self, h5_file_path=H5_FILE_PATH):
        """Load the query encoder, the HDF5 data, and the HNSW index.

        Args:
            h5_file_path: Path to the HDF5 file containing the
                'embeddings' (N x dim float) and 'urls' datasets.

        Raises:
            FileNotFoundError: If the HDF5 file does not exist.
        """
        print("Initializing Paper Search Engine...")

        # Query encoder — must be the same model that produced the stored
        # embeddings, otherwise cosine similarities are meaningless.
        print("Loading Sentence Transformer model (all-roberta-large-v1)...")
        self.model = SentenceTransformer('sentence-transformers/all-roberta-large-v1')
        print("Model loaded successfully!")

        if not os.path.exists(h5_file_path):
            print(f"❌ Error: {h5_file_path} not found!")
            raise FileNotFoundError(f"Required file not found: {h5_file_path}")

        # Keep the HDF5 file open for the process lifetime; datasets are
        # read lazily per lookup rather than loaded fully into memory.
        print(f"Loading embeddings from {h5_file_path}...")
        self.paper = h5py.File(h5_file_path, 'r')

        print(f"Loaded {len(self.paper['urls'])} paper embeddings")
        print(f"Embedding dimension: {self.paper['embeddings'].shape[1]}")

        dim = self.paper["embeddings"].shape[1]
        max_elements = len(self.paper["embeddings"])

        # Prefer the serialized index (fast start-up); otherwise build it
        # from the raw embeddings and persist it for the next start-up.
        if os.path.exists(BIN_FILE_PATH):
            print(f"⚡ Loading HNSW index from {BIN_FILE_PATH}...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.load_index(BIN_FILE_PATH, max_elements=max_elements)
            self.index.set_ef(200)
            print("✅ HNSW index loaded!")
        else:
            print("Building HNSW index from scratch...")
            self.index = hnswlib.Index(space='cosine', dim=dim)
            self.index.init_index(max_elements=max_elements, ef_construction=400, M=200)
            self.index.add_items(self.paper["embeddings"])
            self.index.set_ef(200)
            self.index.save_index(BIN_FILE_PATH)
            print(f"💾 Saved HNSW index to: {BIN_FILE_PATH}")

        print("Paper Search Engine ready!")

    def text_to_vector(self, text):
        """Encode *text* into a single L2-normalized embedding vector."""
        embedding = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)
        return embedding[0]

    def extract_text_from_file(self, file_content, file_extension):
        """Extract plain text from raw file bytes.

        Args:
            file_content: Raw bytes of the uploaded file.
            file_extension: Lower-cased extension including the dot.

        Returns:
            The extracted text as a string.

        Raises:
            ValueError: If the extension is unsupported or PDF parsing fails.
        """
        if file_extension in ['.txt', '.md']:
            try:
                return file_content.decode('utf-8')
            except UnicodeDecodeError:
                # latin-1 never fails to decode, at the cost of possible
                # mojibake — acceptable as a best-effort fallback.
                return file_content.decode('latin-1')
        elif file_extension == '.pdf':
            try:
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
                # extract_text() may return None/empty for image-only
                # pages; guard with `or ""` so concatenation never raises.
                text = ""
                for page in pdf_reader.pages:
                    text += (page.extract_text() or "") + "\n"
                return text.strip()
            except Exception as e:
                raise ValueError(f"Error extracting text from PDF: {str(e)}")
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

    def search(self, query_text, k=10):
        """Return the *k* papers most similar to *query_text*.

        Returns:
            List of {'url': str, 'similarity': float} dicts, most
            similar first.
        """
        query_vector = self.text_to_vector(query_text)
        # hnswlib raises if k exceeds the number of indexed items; clamp
        # so small corpora still answer instead of erroring out.
        k = min(k, self.index.get_current_count())
        labels, distances = self.index.knn_query(query_vector, k=k)
        # Cosine-space distance is d = 1 - cos_sim, so invert it back.
        similarities = 1 - distances[0]
        results = []
        for idx, similarity in zip(labels[0], similarities):
            results.append({
                # 'urls' entries are stored as bytes — TODO confirm the
                # dataset was written with byte strings.
                'url': self.paper["urls"][idx].decode('utf-8'),
                'similarity': float(similarity)
            })
        return results

    def search_by_file(self, file_content, file_extension, k=10):
        """Extract text from an uploaded file and search with it."""
        text = self.extract_text_from_file(file_content, file_extension)
        return self.search(text, k)

# Module-level singleton shared by all request handlers; constructed at
# import time so the service only answers once the index is ready.
search_engine = PaperSearchEngine(h5_file_path=H5_FILE_PATH)

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness endpoint: report service identity and corpus statistics."""
    payload = {
        'status': 'healthy',
        'service': 'paper-search-engine',
        'total_papers': len(search_engine.paper["urls"]),
        'embedding_dim': search_engine.paper["embeddings"].shape[1],
        'model': 'all-roberta-large-v1',
    }
    return jsonify(payload)

@app.route('/search', methods=['POST'])
def search_text():
    """POST /search — semantic search from a JSON body.

    Expects {"query": str, "k": int (optional, 1..100, default 10)}.
    Returns the query, k, and the ranked results; 400 on bad input,
    500 on unexpected failure.
    """
    try:
        data = request.get_json()
        if not data or 'query' not in data:
            return jsonify({'error': 'Missing query parameter'}), 400
        query = data['query']
        k = data.get('k', 10)
        # bool is a subclass of int, so JSON true/false would otherwise
        # slip through the isinstance(k, int) check — reject it explicitly.
        if isinstance(k, bool) or not isinstance(k, int) or k < 1 or k > 100:
            return jsonify({'error': 'k must be an integer between 1 and 100'}), 400
        results = search_engine.search(query, k=k)
        return jsonify({
            'query': query,
            'k': k,
            'results': results
        })
    except Exception as e:
        # Catch-all boundary: surface the message, never a traceback.
        return jsonify({'error': str(e)}), 500

@app.route('/search/file', methods=['POST'])
def search_file():
    """POST /search/file — semantic search using an uploaded document.

    Accepts a multipart form with a 'file' field (.txt, .pdf, or .md)
    and an optional integer 'k' (1..100, default 10). Returns the
    ranked results; 400 on bad input, 500 on unexpected failure.
    """
    try:
        upload = request.files.get('file')
        if upload is None:
            return jsonify({'error': 'No file provided'}), 400
        if upload.filename == '':
            return jsonify({'error': 'Empty filename'}), 400
        extension = os.path.splitext(upload.filename)[1].lower()
        if extension not in ('.txt', '.pdf', '.md'):
            return jsonify({'error': 'Unsupported file type. Supported: .txt, .pdf, .md'}), 400
        raw_bytes = upload.read()
        k = request.form.get('k', 10, type=int)
        if not 1 <= k <= 100:
            return jsonify({'error': 'k must be between 1 and 100'}), 400
        results = search_engine.search_by_file(raw_bytes, extension, k=k)
        return jsonify({
            'filename': upload.filename,
            'k': k,
            'results': results
        })
    except ValueError as e:
        # Raised for unsupported types / unparsable PDFs — client error.
        return jsonify({'error': str(e)}), 400
    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Bind to all interfaces on 7860, the standard Hugging Face Spaces port.
    app.run(host='0.0.0.0', port=7860)