File size: 2,598 Bytes
2fb9f86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Usage Examples

## 1. Basic Example
```python
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

model = SentenceTransformer('ThanhLe0125/ebd-math')

# Single query example
query = "query: Cách tính đạo hàm của hàm số"
chunks = [
    "passage: Đạo hàm của hàm số f(x) tại điểm x0 được định nghĩa...",
    "passage: Các quy tắc tính đạo hàm cơ bản: (x^n)' = nx^(n-1)...",
    "passage: Phương trình tích phân là phương trình chứa hàm số..."
]

query_emb = model.encode([query])
chunk_embs = model.encode(chunks)
similarities = cosine_similarity(query_emb, chunk_embs)[0]

print("Rankings:")
for i, sim in enumerate(similarities):
    print(f"Chunk {i+1}: {sim:.4f}")
```

## 2. Batch Processing
```python
queries = [
    "query: Định nghĩa hàm số đồng biến",
    "query: Cách giải phương trình bậc hai",
    "query: Công thức tính thể tích hình cầu"
]

# Encode all at once for efficiency
query_embs = model.encode(queries)
chunk_embs = model.encode(chunks)

# Calculate similarities for all queries
for i, query in enumerate(queries):
    similarities = cosine_similarity([query_embs[i]], chunk_embs)[0]
    best_idx = similarities.argmax()
    print(f"Best match for '{query}': {chunks[best_idx]} (score: {similarities[best_idx]:.4f})")
```

## 3. Production Usage
```python
class MathRetriever:
    def __init__(self, model_name='ThanhLe0125/ebd-math'):
        self.model = SentenceTransformer(model_name)
    
    def retrieve(self, query, chunks, top_k=5):
        # Format inputs
        formatted_query = f"query: {query}" if not query.startswith("query:") else query
        formatted_chunks = [f"passage: {chunk}" if not chunk.startswith("passage:") else chunk 
                          for chunk in chunks]
        
        # Encode and rank
        query_emb = self.model.encode([formatted_query])
        chunk_embs = self.model.encode(formatted_chunks)
        similarities = cosine_similarity(query_emb, chunk_embs)[0]
        
        # Get top K results
        top_indices = similarities.argsort()[::-1][:top_k]
        results = [
            {
                'chunk': chunks[i],
                'similarity': float(similarities[i]),
                'rank': rank + 1
            }
            for rank, i in enumerate(top_indices)
        ]
        
        return results

# Usage
retriever = MathRetriever()
results = retriever.retrieve(
    "Định nghĩa hàm số liên tục", 
    mathematical_chunks, 
    top_k=3
)
```