|
|
|
|
|
"""
|
|
|
Created on Fri Dec 5 10:25:01 2025
|
|
|
|
|
|
@author: marco.minervini
|
|
|
"""
|
|
|
|
|
|
|
|
|
import json
|
|
|
import logging
|
|
|
from typing import List
|
|
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
# Set this logger's threshold to INFO; records below that level are dropped.
# NOTE(review): no handler is configured in the visible code - whether records
# are actually emitted depends on logging configuration done elsewhere.
logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
def model_fn(model_dir):
    """
    Load the embedding model from ``model_dir``.

    SageMaker calls this once when the container starts; ``model_dir`` is the
    directory on disk that holds the HF model files.

    Args:
        model_dir: Path handed straight to the ``SentenceTransformer``
            constructor.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_dir)
|
|
|
|
|
|
|
|
|
|
|
|
def input_fn(request_body, content_type):
    """
    Turn the incoming payload into a Python object (list of texts).

    Supports:

    - text/plain: one text per line (lines are stripped, blank lines dropped)
    - application/json: {"texts": ["...", "..."]} or ["...", "..."]

    Args:
        request_body: Raw payload as ``str``, ``bytes`` or ``bytearray``.
            Binary payloads are decoded as UTF-8 before parsing.
        content_type: MIME type of the payload. Matching is case-insensitive
            and ignores parameters such as ``; charset=utf-8``.

    Returns:
        The list of texts to embed.

    Raises:
        ValueError: If the content type is unsupported, if the JSON payload is
            neither a list nor an object with a ``"texts"`` list, or if the
            payload is not valid UTF-8 / JSON (``UnicodeDecodeError`` and
            ``json.JSONDecodeError`` are both ``ValueError`` subclasses).
    """
    # Decode binary payloads up front so both branches work on ``str``;
    # otherwise text/plain would yield ``bytes`` items that cannot be
    # embedded or serialised back to JSON.
    if isinstance(request_body, (bytes, bytearray)):
        request_body = request_body.decode("utf-8")

    # Compare only the bare media type: "Application/JSON; charset=utf-8"
    # must be treated exactly like "application/json".
    if isinstance(content_type, str):
        media_type = content_type.split(";", 1)[0].strip().lower()
    else:
        media_type = ""

    if media_type == "text/plain":
        stripped_lines = (line.strip() for line in request_body.splitlines())
        return [line for line in stripped_lines if line]

    if media_type == "application/json":
        data = json.loads(request_body)

        if isinstance(data, dict) and "texts" in data:
            texts = data["texts"]
            # A bare string here would later be iterated character by
            # character, so reject anything that is not a list.
            if not isinstance(texts, list):
                raise ValueError("The 'texts' key must hold a list of texts.")
            return texts

        if isinstance(data, list):
            return data

        raise ValueError("JSON input must be a list or have a 'texts' key.")

    raise ValueError(f"Unsupported content type: {content_type}")
|
|
|
|
|
|
|
|
|
|
|
|
def predict_fn(texts: List[str], model: SentenceTransformer):
    """
    Embed each text individually and report one result record per input.

    A failure on one record (empty text, non-string value, or any exception
    raised while encoding) is logged as a warning and recorded with
    ``"ok": False`` instead of being propagated, so one bad record does not
    abort the rest of the batch.

    Args:
        texts: The texts to embed, processed in order.
        model: Object whose ``encode(text)`` result exposes ``tolist()``.

    Returns:
        A list with one dict per input, in input order. Successful records
        carry ``index``, ``ok``, ``text`` and ``embedding``; failed records
        additionally carry ``error`` and have ``embedding`` set to ``None``.
    """
    outcomes = []

    for position, candidate in enumerate(texts):
        try:
            # Only non-blank strings are embeddable; everything else is
            # routed to the failure record through the handler below.
            embeddable = isinstance(candidate, str) and candidate.strip()
            if not embeddable:
                raise ValueError("Empty or non-string text.")

            vector = model.encode(candidate).tolist()
        except Exception as exc:
            logger.warning(f"Failed to embed record {position}: {exc} | text={repr(candidate)}")
            record = {
                "index": position,
                "ok": False,
                "text": candidate,
                "error": str(exc),
                "embedding": None,
            }
        else:
            record = {
                "index": position,
                "ok": True,
                "text": candidate,
                "embedding": vector,
            }

        outcomes.append(record)

    return outcomes
|
|
|
|
|
|
|
|
|
|
|
|
def output_fn(prediction, accept):
    """
    Serialise the prediction into the JSON text that SageMaker writes to S3.

    The result is always a single JSON document labelled
    ``application/json``. ``accept`` is kept for the handler signature but
    does not influence the output: JSON-like and unknown accept types were
    already handled identically, so the duplicated branch is collapsed here.

    Args:
        prediction: JSON-serialisable object to emit.
        accept: Requested response MIME type (currently not honoured).

    Returns:
        A ``(body, content_type)`` tuple where ``body`` is the JSON string
        produced by ``json.dumps`` and ``content_type`` is
        ``"application/json"``.

    Raises:
        TypeError: If ``prediction`` contains values ``json.dumps`` cannot
            serialise.
    """
    return json.dumps(prediction), "application/json"
|
|
|
|