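"""Production Sentence Embedding Model API.

Loads an exported TorchScript sentence-embedding model and its vocabulary from
disk and provides helpers to embed texts, score similarity and benchmark
throughput.
"""
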
import json
import os
import re
import sys
import time
from typing import List, Union, Tuple, Dict

import numpy as np
import torch


class SentenceEmbeddingInference:
    """CPU inference wrapper around an exported TorchScript sentence-embedding model.

    Expected layout below ``model_dir``:

    * ``exports/model_torchscript.pt`` - the TorchScript model (required)
    * ``tokenizer/vocab.json`` - token -> id mapping (required)
    * ``tokenizer/id_to_token.json`` - id -> token mapping (optional; derived
      from the vocabulary when the file is absent)

    Loading is attempted in the constructor. It never raises; check
    ``self.model is None`` (or call :meth:`load_models` and inspect its return
    value) to find out whether the instance is usable.
    """

    # Every input is truncated / padded to this many token ids before it is
    # handed to the model.
    MAX_SEQUENCE_LENGTH = 128

    # Added to vector norms so normalising a zero vector does not divide by 0.
    _NORM_EPS = 1e-8

    def __init__(self, model_dir: str):
        """Remember ``model_dir`` and try to load the model and tokenizer from it."""
        self.model_dir = model_dir
        self.model = None          # TorchScript module, or None when not loaded
        self.vocab = None          # dict: token string -> token id
        self.id_to_token = None    # dict: token id -> token string
        # A "word" is a run of word characters or one of a few punctuation marks.
        self.word_pattern = re.compile(r'\b\w+\b|[.,!?;]')
        self.load_models()

    def load_models(self) -> bool:
        """Load the TorchScript model and the tokenizer files from ``model_dir``.

        State is committed atomically: ``self.model``, ``self.vocab`` and
        ``self.id_to_token`` are only assigned once *everything* has loaded, so
        a failure can never leave a model without a vocabulary (which would
        make every text encode to an empty, all-padding sequence).

        Returns:
            True when model and vocabulary were loaded, False otherwise. The
            reason for a failure is printed rather than raised.
        """
        print("🔄 Loading sentence embedding model...")

        try:
            torchscript_path = os.path.join(self.model_dir, "exports", "model_torchscript.pt")
            if not os.path.exists(torchscript_path):
                print("⚠️ TorchScript model not found")
                return False

            model = torch.jit.load(torchscript_path, map_location='cpu')
            print("✅ Loaded TorchScript model")

            vocab_path = os.path.join(self.model_dir, "tokenizer", "vocab.json")
            if not os.path.exists(vocab_path):
                # Without a vocabulary nothing can be tokenised; treat as a load failure.
                raise FileNotFoundError(f"vocabulary file not found: {vocab_path}")
            with open(vocab_path, 'r', encoding='utf-8') as f:
                vocab = json.load(f)
            print(f"✅ Loaded vocabulary with {len(vocab)} tokens")

            id_to_token_path = os.path.join(self.model_dir, "tokenizer", "id_to_token.json")
            if os.path.exists(id_to_token_path):
                with open(id_to_token_path, 'r', encoding='utf-8') as f:
                    # JSON object keys are always strings; convert them back to ids.
                    id_to_token = {int(k): v for k, v in json.load(f).items()}
            else:
                id_to_token = {v: k for k, v in vocab.items()}

            model.eval()
        except Exception as e:
            # Deliberately broad: this is the loading boundary and callers rely
            # on a boolean result instead of an exception.
            print(f"❌ Failed to load model: {e}")
            return False

        self.model = model
        self.vocab = vocab
        self.id_to_token = id_to_token
        print("✅ Model ready for inference")
        return True

    def encode_text(self, text: str) -> List[int]:
        """Convert ``text`` into a list of token ids framed by [CLS] and [SEP].

        The text is lower-cased and split with ``self.word_pattern``. Each word
        is looked up as ``word + "</w>"`` first, then as the bare word; words
        unknown in both forms are encoded character by character, with
        ``[UNK]`` standing in for unknown characters.

        Returns:
            The token ids, or an empty list when ``text`` is empty or no
            vocabulary is loaded.
        """
        if not text or not self.vocab:
            return []

        vocab = self.vocab
        unk_id = vocab.get("[UNK]", 1)
        token_ids = [vocab.get("[CLS]", 2)]

        for word in self.word_pattern.findall(text.lower()):
            word_with_boundary = word + "</w>"
            if word_with_boundary in vocab:
                token_ids.append(vocab[word_with_boundary])
            elif word in vocab:
                token_ids.append(vocab[word])
            else:
                # Character-level fallback for out-of-vocabulary words.
                token_ids.extend(vocab.get(char, unk_id) for char in word)

        token_ids.append(vocab.get("[SEP]", 3))
        return token_ids

    def get_embeddings(self, texts: Union[str, List[str]], batch_size: int = 8) -> np.ndarray:
        """Embed one text or a list of texts.

        Args:
            texts: A single string or a list of strings.
            batch_size: Must be >= 1. Texts are run through the model one at a
                time (the exported model is only ever called with a batch
                dimension of 1), so this value does not change the result.

        Returns:
            ``np.array`` of the per-text model outputs (first dimension
            squeezed), in input order. An empty input list yields an empty
            array.

        Raises:
            RuntimeError: If no model is loaded.
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if isinstance(texts, str):
            texts = [texts]

        # Compare against None explicitly: truth-testing an object falls back
        # to __len__, which is wrong for a loaded model whose length is 0 and,
        # presumably, raises for TorchScript modules — TODO confirm.
        if self.model is None:
            raise RuntimeError("Model not loaded.")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        max_len = self.MAX_SEQUENCE_LENGTH
        # Same lookup style as the other special tokens; 0 is the historic pad id.
        pad_id = self.vocab.get("[PAD]", 0) if self.vocab else 0

        embeddings = []
        with torch.no_grad():
            for text in texts:
                # NOTE(review): truncating here can cut off the trailing [SEP]
                # of long inputs; kept as-is so embeddings stay unchanged -
                # confirm against the preprocessing used at training time.
                token_ids = self.encode_text(text)[:max_len]
                pad_count = max_len - len(token_ids)

                input_ids = torch.tensor([token_ids + [pad_id] * pad_count], dtype=torch.long)
                attention_mask = torch.tensor(
                    [[1] * len(token_ids) + [0] * pad_count], dtype=torch.float
                )

                embedding = self.model(input_ids, attention_mask)
                embeddings.append(embedding.squeeze(0).cpu().numpy())

        return np.array(embeddings)

    @classmethod
    def _unit(cls, vector: np.ndarray) -> np.ndarray:
        """Scale ``vector`` to (approximately) unit L2 norm."""
        return vector / (np.linalg.norm(vector) + cls._NORM_EPS)

    def compute_similarity(self, text1: str, text2: str) -> float:
        """Return the cosine similarity of two texts, clipped to [-1.0, 1.0].

        Raises:
            RuntimeError: If no model is loaded.
        """
        first, second = self.get_embeddings([text1, text2])
        similarity = np.dot(self._unit(first), self._unit(second))
        return float(np.clip(similarity, -1.0, 1.0))

    def find_similar_texts(self, query: str, candidates: List[str], top_k: int = 5) -> List[Tuple[str, float]]:
        """Rank ``candidates`` by cosine similarity to ``query``.

        Args:
            query: Text to compare against.
            candidates: Texts to rank.
            top_k: Maximum number of results; values <= 0 yield an empty list.

        Returns:
            Up to ``top_k`` ``(candidate, similarity)`` pairs, best first.
            Similarities are clipped to [-1.0, 1.0]; candidates with equal
            scores keep their input order.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if not candidates or top_k <= 0:
            return []

        query_unit = self._unit(self.get_embeddings([query])[0])

        candidate_embeddings = self.get_embeddings(candidates)
        candidate_norms = np.linalg.norm(candidate_embeddings, axis=1, keepdims=True)
        candidate_units = candidate_embeddings / (candidate_norms + self._NORM_EPS)

        scores = np.clip(candidate_units @ query_unit, -1.0, 1.0).tolist()

        # sorted() is stable, also with reverse=True, so ties keep input order.
        ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
        return ranked[:top_k]

    def benchmark_performance(self, num_texts: int = 100) -> Dict[str, float]:
        """Time the embedding of ``num_texts`` synthetic sentences and print a summary.

        Returns:
            Dict with the keys ``texts_per_second``, ``avg_time_per_text_ms``,
            ``total_time_seconds``, ``embedding_memory_mb`` and
            ``embedding_dimensions``.

        Raises:
            ValueError: If ``num_texts`` is smaller than 1.
            RuntimeError: If no model is loaded.
        """
        if num_texts < 1:
            raise ValueError(f"num_texts must be >= 1, got {num_texts}")

        print(f"🚀 Benchmarking performance with {num_texts} texts...")

        test_texts = [
            f"This is test sentence number {i} for benchmarking performance."
            for i in range(num_texts)
        ]

        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        start_time = time.perf_counter()
        embeddings = self.get_embeddings(test_texts)
        total_time = time.perf_counter() - start_time

        # Guard against a zero reading on very coarse timers.
        texts_per_second = num_texts / total_time if total_time > 0 else float("inf")
        avg_time_per_text = total_time / num_texts * 1000
        embedding_memory_mb = embeddings.nbytes / (1024 * 1024)

        results = {
            'texts_per_second': texts_per_second,
            'avg_time_per_text_ms': avg_time_per_text,
            'total_time_seconds': total_time,
            'embedding_memory_mb': embedding_memory_mb,
            'embedding_dimensions': embeddings.shape[1],
        }

        print("📊 Benchmark Results:")
        print(f" Texts per second: {texts_per_second:.1f}")
        print(f" Average time per text: {avg_time_per_text:.2f}ms")
        print(f" Embedding dimensions: {embeddings.shape[1]}")
        print(f" Memory usage: {embedding_memory_mb:.2f}MB")

        return results


def main() -> None:
    """Smoke-test the model stored in the current directory.

    Embeds a handful of sample sentences, prints one pairwise similarity and a
    small similarity search, then runs a 50-text benchmark. Exits with status 1
    when the model could not be loaded.
    """
    model = SentenceEmbeddingInference("./")

    if model.model is None:
        print("❌ Failed to load model. Exiting.")
        # sys.exit instead of the bare exit() helper, which is injected by the
        # site module and is not guaranteed to exist in every interpreter setup.
        sys.exit(1)

    test_sentences = [
        "The cat sat on the mat.",
        "A feline rested on the rug.",
        "Dogs are loyal companions.",
        "Programming requires logical thinking.",
        "Machine learning transforms data into insights.",
        "Natural language processing helps computers understand text.",
    ]

    print("\n🧪 Testing sentence embeddings...")

    embeddings = model.get_embeddings(test_sentences)
    print(f"Generated embeddings shape: {embeddings.shape}")

    similarity = model.compute_similarity(test_sentences[0], test_sentences[1])
    print("\nSimilarity between:")
    print(f" '{test_sentences[0]}'")
    print(f" '{test_sentences[1]}'")
    print(f" Similarity: {similarity:.4f}")

    query = "What are cats like?"
    similar_texts = model.find_similar_texts(query, test_sentences, top_k=3)
    print(f"\nMost similar to '{query}':")
    for text, score in similar_texts:
        print(f" {score:.4f}: {text}")

    print("\n" + "=" * 50)
    model.benchmark_performance(50)

    print("\n✅ Model testing completed successfully!")


if __name__ == "__main__":
    main()