import os

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


class QwenEmbeddings:
    def __init__(self, model_name="Qwen/Qwen3-Embedding-8B", max_length=512, batch_size=8):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Left padding keeps the final real token of every sequence at the
        # same position, which makes last-token pooling a single slice.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()  # inference only: disable dropout
        self.max_length = max_length
        self.batch_size = batch_size

    def last_token_pool(self, last_hidden_states, attention_mask):
        # With left padding, every row's last position holds a real token.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        # Otherwise index each row at its own final non-padding position.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
    def encode(self, texts):
        all_embeddings = []
        for i in tqdm(range(0, len(texts), self.batch_size)):
            batch = texts[i:i + self.batch_size]
            enc = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)
            with torch.no_grad():
                out = self.model(**enc)
            pooled = self.last_token_pool(out.last_hidden_state, enc["attention_mask"])
            # L2-normalize so dot products equal cosine similarity.
            pooled = F.normalize(pooled, p=2, dim=1)
            # Move results to CPU and free GPU memory between batches; an 8B
            # model leaves little headroom for activations.
            all_embeddings.append(pooled.cpu())
            torch.cuda.empty_cache()
        return torch.cat(all_embeddings).numpy()
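

# A minimal sanity check of the pooling logic (illustration only: the toy
# tensors below are made up, and this helper is not part of the pipeline).
def _demo_last_token_pool():
    hidden = torch.arange(12, dtype=torch.float).reshape(2, 3, 2)  # (batch, seq, dim)
    # Row 1 has no padding; row 2 is left-padded by one position, so the
    # mask's last column is all ones and pooling reduces to hidden[:, -1].
    mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
    pooled = QwenEmbeddings.last_token_pool(None, hidden, mask)
    assert torch.equal(pooled, hidden[:, -1])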


class Embedder:
    def __init__(self, backend="qwen"):
        """
        backend = "mini_lm" or "qwen"
        """
        self.backend = backend
        self.model = None  # loaded lazily on first encode call

    def _ensure_model(self):
        # Load the backend model on first use, so encode_query() also works
        # without a prior encode_books() call.
        if self.model is not None:
            return
        if self.backend == "mini_lm":
            self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        elif self.backend == "qwen":
            self.model = QwenEmbeddings()
        else:
            raise ValueError("backend must be 'mini_lm' or 'qwen'")

    def encode_books(self, df, text_column="combined_text", batch_size=32):
        self._ensure_model()
        texts = df[text_column].tolist()
        if self.backend == "mini_lm":
            embeddings = self.model.encode(texts, batch_size=batch_size, show_progress_bar=True)
            return np.array(embeddings)
        # Qwen backend batches internally and already returns a numpy array.
        return self.model.encode(texts)
    def encode_query(self, text):
        self._ensure_model()
        # Both backends return an (n, dim) array; take the single row.
        return self.model.encode([text])[0]
    def save_embeddings(self, embeddings, path="models/embeddings.npy"):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        np.save(path, embeddings)

    def load_embeddings(self, path="models/embeddings.npy"):
        return np.load(path)
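

# A minimal end-to-end usage sketch (the two-row DataFrame and the query
# string are stand-ins, not real data); runs only when executed directly.
if __name__ == "__main__":
    import pandas as pd

    books = pd.DataFrame({"combined_text": [
        "A sweeping space opera about a desert planet.",
        "A cosy mystery set in a seaside bookshop.",
    ]})

    embedder = Embedder(backend="mini_lm")  # lighter backend for a quick local check
    embeddings = embedder.encode_books(books)
    embedder.save_embeddings(embeddings)

    # Both backends emit (near-)unit-norm vectors, so a dot product serves
    # as a cosine-similarity score.
    query_vec = embedder.encode_query("space adventure novels")
    print(embeddings @ query_vec)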