"""Convert a local sentence-transformers model to TorchScript.

Loads the model from the current directory, traces the underlying
transformer module with an example input, and saves the result as
``model.pt``. Only the transformer itself is traced; the pooling and
normalization layers of the SentenceTransformer pipeline are NOT
included in the exported graph.
"""
import os
import traceback

import torch
from sentence_transformers import SentenceTransformer


def main() -> None:
    """Trace the local model on CPU and save it as TorchScript (model.pt).

    Any failure is reported to stdout with a full traceback rather than
    letting the script crash — this is a top-level boundary handler.
    """
    try:
        # Load the local sentence-transformers model from the working dir.
        model_path = "."
        print(f"Loading model from: {model_path}")

        # Force CPU so no device-specific ops get baked into the traced graph.
        model = SentenceTransformer(model_path, device='cpu')
        # Evaluation mode: disables dropout etc., making the trace deterministic.
        model.eval()

        # Build a concrete example input for tracing.
        sample_text = "This is an example sentence to encode."
        print(f"Creating example input with text: '{sample_text}'")

        tokenizer = model.tokenizer
        # max_length=256 carried over from the original script — presumably the
        # model's configured sequence limit; verify against the model config.
        inputs = tokenizer(
            sample_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )

        print("Tracing model - this may take a moment...")
        with torch.no_grad():
            # Trace only the underlying HF transformer (model[0].auto_model).
            # strict=False tolerates the dictionary outputs of its forward().
            traced_model = torch.jit.trace(
                model[0].auto_model,
                (inputs["input_ids"], inputs["attention_mask"]),
                strict=False,
            )

        # Persist the traced graph next to the script.
        output_path = "model.pt"
        traced_model.save(output_path)
        print(f"Model successfully converted to TorchScript and saved as {output_path}")
        print(f"Full path: {os.path.abspath(output_path)}")
        print(f"Note: This traces only the transformer model. Pooling and normalization layers are not included.")
    except Exception as e:
        # Report and dump the traceback instead of crashing with a raw stack.
        print(f"Error converting model: {str(e)}")
        traceback.print_exc()


if __name__ == "__main__":
    main()