"""Convert a local sentence-transformers model to TorchScript.

Loads the model from the current directory, traces its underlying
transformer on CPU, and saves the result as ``model.pt``.
"""
import torch
from sentence_transformers import SentenceTransformer
import os

try:
    # Path to the local sentence-transformers checkpoint (current directory).
    model_path = "."
    print(f"Loading model from: {model_path}")

    # CPU keeps the trace free of device-specific ops for TorchScript.
    model = SentenceTransformer(model_path, device='cpu')
    model.eval()  # disable dropout/batch-norm updates before tracing

    # Concrete example input — tracing records ops on real tensors.
    sample_text = "This is an example sentence to encode."
    print(f"Creating example input with text: '{sample_text}'")

    # Tokenize with the model's own tokenizer (model max_length is 256).
    encoded = model.tokenizer(
        sample_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    )

    print("Tracing model - this may take a moment...")
    with torch.no_grad():
        # model[0].auto_model is the raw HF transformer module;
        # strict=False tolerates its dict-style outputs during tracing.
        traced = torch.jit.trace(
            model[0].auto_model,
            (encoded["input_ids"], encoded["attention_mask"]),
            strict=False,
        )

    # Persist the traced graph next to the script.
    output_path = "model.pt"
    traced.save(output_path)

    print(f"Model successfully converted to TorchScript and saved as {output_path}")
    print(f"Full path: {os.path.abspath(output_path)}")
    print(f"Note: This traces only the transformer model. Pooling and normalization layers are not included.")

except Exception as e:
    # Top-level boundary: report the failure with a full traceback.
    print(f"Error converting model: {str(e)}")
    import traceback
    traceback.print_exc()