# -*- coding: utf-8 -*-
"""
Created on Fri Dec 5 10:25:01 2025

@author: marco.minervini

SageMaker inference handlers (model_fn / input_fn / predict_fn / output_fn)
for serving sentence-transformers embeddings. predict_fn isolates failures
per record so a Batch Transform job does not abort on one bad input.
"""
# inference.py
import json
import logging
from typing import List, Union

from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# 1. Load the sentence-transformers model
def model_fn(model_dir: str) -> SentenceTransformer:
    """
    SageMaker calls this once when the container starts.

    Parameters
    ----------
    model_dir : str
        Directory on disk where the HF model files were extracted
        (the contents of model.tar.gz).

    Returns
    -------
    SentenceTransformer
        The loaded embedding model.
    """
    # If you bundled the HF files into the model.tar.gz, just load from model_dir.
    # OR, to download by name at startup instead:
    #   model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return SentenceTransformer(model_dir)


# 2. Parse the incoming batch of text
def input_fn(request_body: Union[str, bytes], content_type: str) -> List[str]:
    """
    Turn the incoming payload into a Python object (list of texts).

    Supports:
      - text/plain:        one text per line (blank lines dropped)
      - application/json:  {"texts": ["...", "..."]} or ["...", "..."]

    Raises
    ------
    ValueError
        If the content type is unsupported or the JSON payload has
        the wrong shape.
    """
    # SageMaker may deliver the body as raw bytes; normalize to str up front
    # so both branches below operate on text.
    if isinstance(request_body, (bytes, bytearray)):
        request_body = request_body.decode("utf-8")

    # Content types may carry parameters ("application/json; charset=utf-8");
    # compare only the media type itself, case-insensitively.
    media_type = content_type.split(";", 1)[0].strip().lower() if content_type else ""

    if media_type == "text/plain":
        # Each line = one record; strip empties
        return [line.strip() for line in request_body.splitlines() if line.strip()]

    if media_type == "application/json":
        data = json.loads(request_body)
        if isinstance(data, dict) and "texts" in data:
            texts = data["texts"]
            # Guard against {"texts": "single string"} and similar shapes.
            if not isinstance(texts, list):
                raise ValueError("JSON input must be a list or have a 'texts' key.")
            return texts
        if isinstance(data, list):
            return data
        raise ValueError("JSON input must be a list or have a 'texts' key.")

    # Anything else is unsupported
    raise ValueError(f"Unsupported content type: {content_type}")


# 3. Run the model with per-record exception handling
def predict_fn(texts: List[str], model: SentenceTransformer) -> List[dict]:
    """
    Run embeddings with robust per-record error handling.

    We never raise out of this function, so the Batch Transform job won't
    crash: a record that fails to embed is returned as an error object
    ({"ok": False, "error": ..., "embedding": None}) instead.
    """
    results = []
    for idx, text in enumerate(texts):
        try:
            if not isinstance(text, str) or not text.strip():
                raise ValueError("Empty or non-string text.")
            # sentence-transformers encode -> numpy array
            embedding = model.encode(text)
            results.append(
                {
                    "index": idx,  # position in the batch
                    "ok": True,  # success flag
                    "text": text,
                    "embedding": embedding.tolist(),
                }
            )
        except Exception as e:  # deliberate broad catch: isolate bad records
            # Log for CloudWatch; lazy %-args avoid formatting work when
            # the level is disabled.
            logger.warning("Failed to embed record %d: %s | text=%r", idx, e, text)
            # Return an error object instead of crashing
            results.append(
                {
                    "index": idx,
                    "ok": False,
                    "text": text,
                    "error": str(e),
                    "embedding": None,
                }
            )
    return results


# 4. Serialize output
def output_fn(prediction, accept):
    """
    Turn the Python object into the payload SageMaker writes to S3.

    Returns a (body, content_type) tuple. For "application/jsonlines" we
    emit one JSON object per line (the shape Batch Transform expects when
    assembling output line-by-line); every other accept type — including
    unknown ones — falls back to a single JSON document.
    """
    if accept == "application/jsonlines":
        # JSON Lines: one record per line, not a single JSON array.
        body = "\n".join(json.dumps(record) for record in prediction)
        return body, "application/jsonlines"
    # Default / fallback: a single JSON document covers "application/json",
    # "text/json", and anything unrecognized.
    return json.dumps(prediction), "application/json"