# all-MiniLM-L6-v2 / code / inference.py
# Uploaded via huggingface_hub (commit 9f9e23e, verified)
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 5 10:25:01 2025
@author: marco.minervini
"""
# inference.py
import json
import logging
from typing import List
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# 1. Load the sentence-transformers model
def model_fn(model_dir):
    """
    SageMaker model-loading hook, invoked once at container startup.

    Args:
        model_dir: Directory on disk containing the HF model files
            (extracted from model.tar.gz).

    Returns:
        A ready-to-use SentenceTransformer instance.
    """
    # Load directly from the bundled files. Alternatively, download by
    # name at startup:
    #   SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return SentenceTransformer(model_dir)
# 2. Parse the incoming batch of text
def input_fn(request_body, content_type):
    """
    Deserialize the incoming payload into a list of texts.

    Supports:
      - text/plain: one text per line (blank lines dropped)
      - application/json: {"texts": ["...", "..."]} or a bare JSON list

    Args:
        request_body: Raw payload; SageMaker may pass str or bytes.
        content_type: MIME type, possibly with parameters
            (e.g. "application/json; charset=utf-8").

    Returns:
        List of input texts.

    Raises:
        ValueError: for unsupported content types or a JSON payload that
            is neither a list nor a dict with a "texts" key.
    """
    # SageMaker commonly hands the body over as raw bytes; normalize to str.
    if isinstance(request_body, (bytes, bytearray)):
        request_body = request_body.decode("utf-8")

    # Strip content-type parameters such as "; charset=utf-8" before matching.
    media_type = content_type.split(";")[0].strip().lower() if content_type else ""

    if media_type == "text/plain":
        # Each non-empty line is one record.
        texts: List[str] = [line.strip() for line in request_body.splitlines() if line.strip()]
        return texts
    if media_type == "application/json":
        data = json.loads(request_body)
        if isinstance(data, dict) and "texts" in data:
            return data["texts"]
        if isinstance(data, list):
            return data
        raise ValueError("JSON input must be a list or have a 'texts' key.")
    # Anything else is unsupported
    raise ValueError(f"Unsupported content type: {content_type}")
# 3. Run the model with per-record exception handling
def predict_fn(texts: List[str], model: SentenceTransformer):
    """
    Embed each text, never letting one bad record abort the whole batch.

    Exceptions are caught per record, so a Batch Transform job keeps
    running even when individual inputs are invalid.

    Returns:
        One dict per input, in order:
          {"index", "ok": True,  "text", "embedding": [...]} on success,
          {"index", "ok": False, "text", "error", "embedding": None} on failure.
    """
    outputs = []
    for position, item in enumerate(texts):
        try:
            if not (isinstance(item, str) and item.strip()):
                raise ValueError("Empty or non-string text.")
            # sentence-transformers encode returns a numpy array.
            vector = model.encode(item)
            record = {
                "index": position,  # position in the batch
                "ok": True,         # success flag
                "text": item,
                "embedding": vector.tolist(),
            }
        except Exception as exc:
            # Log for CloudWatch, then emit an error object instead of crashing.
            logger.warning(f"Failed to embed record {position}: {exc} | text={repr(item)}")
            record = {
                "index": position,
                "ok": False,
                "text": item,
                "error": str(exc),
                "embedding": None,
            }
        outputs.append(record)
    return outputs
# 4. Serialize output
def output_fn(prediction, accept):
    """
    Serialize the prediction results into the response body SageMaker
    writes to S3.

    Args:
        prediction: List of per-record result dicts from predict_fn.
        accept: Requested response MIME type.

    Returns:
        (body, content_type) tuple.
    """
    # Original code had identical if/fallback branches, ignoring a
    # jsonlines accept; honor it with one JSON object per line so Batch
    # Transform output can be associated record-by-record.
    if accept == "application/jsonlines":
        body = "\n".join(json.dumps(record) for record in prediction)
        return body, "application/jsonlines"
    # Default (including application/json, text/json, anything else): JSON.
    return json.dumps(prediction), "application/json"