File size: 3,525 Bytes
9f9e23e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
"""

Created on Fri Dec  5 10:25:01 2025



@author: marco.minervini

"""

# inference.py
import json
import logging
from typing import List

from sentence_transformers import SentenceTransformer

# Module-level logger named after this module. Its own level is set to INFO,
# so INFO-and-above records emitted in this file pass the logger-level filter.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# 1. Load the sentence-transformers model
def model_fn(model_dir):
    """
    Build the SentenceTransformer used for inference.

    SageMaker calls this once when the container starts.

    Args:
        model_dir: Directory on disk holding the model files (the unpacked
            contents of model.tar.gz).

    Returns:
        The SentenceTransformer instance constructed from ``model_dir``.
    """
    # The model is read from the bundled files in model_dir. An alternative
    # is to construct it from a hub name at startup instead, e.g.
    # "sentence-transformers/all-MiniLM-L6-v2".
    return SentenceTransformer(model_dir)


# 2. Parse the incoming batch of text
def input_fn(request_body, content_type):
    """
    Turn the incoming payload into a Python object (list of texts).

    Supports:
      - text/plain: one text per line; lines are stripped and blank lines
        are dropped
      - application/json: {"texts": ["...", "..."]} or ["...", "..."]

    The media type is matched case-insensitively and any parameters after
    ";" (for example "; charset=utf-8") are ignored.

    Args:
        request_body: Raw payload as ``str``, ``bytes`` or ``bytearray``.
            Binary text/plain payloads are decoded as UTF-8.
        content_type: Content-Type of the request.

    Returns:
        For text/plain, a list of non-empty stripped lines. For JSON, the
        value stored under "texts" when the document is an object with that
        key, or the document itself when it is a list (returned unvalidated).

    Raises:
        ValueError: If the content type is unsupported, the JSON document is
            neither a list nor an object with a "texts" key, the JSON is
            malformed, or a binary text/plain body is not valid UTF-8
            (``JSONDecodeError`` and ``UnicodeDecodeError`` are both
            ``ValueError`` subclasses).
    """
    # Compare only the bare media type so headers such as
    # "text/plain; charset=utf-8" or "Application/JSON" are accepted.
    if isinstance(content_type, str):
        media_type = content_type.split(";", 1)[0].strip().lower()
    else:
        media_type = ""

    if media_type == "text/plain":
        # Without decoding, a bytes body would yield bytes lines that the
        # prediction step rejects as non-string records.
        if isinstance(request_body, (bytes, bytearray)):
            request_body = request_body.decode("utf-8")
        stripped_lines = (line.strip() for line in request_body.splitlines())
        return [line for line in stripped_lines if line]

    if media_type == "application/json":
        # json.loads accepts str as well as bytes/bytearray directly.
        data = json.loads(request_body)
        if isinstance(data, dict) and "texts" in data:
            return data["texts"]
        if isinstance(data, list):
            return data
        raise ValueError("JSON input must be a list or have a 'texts' key.")

    # Anything else is unsupported
    raise ValueError(f"Unsupported content type: {content_type}")


# 3. Run the model with per-record exception handling
def predict_fn(texts: List[str], model: SentenceTransformer):
    """
    Embed each text on its own and report the outcome per record.

    A failure on one record (including an empty or non-string entry) is
    logged as a warning and turned into an error entry, so the remaining
    records are still processed.

    Args:
        texts: The records to embed, one entry per record.
        model: Object whose ``encode(text)`` result provides ``tolist()``.

    Returns:
        One dict per input, in input order. Successful entries hold
        "index", "ok" (True), "text" and "embedding" (the encoded vector as
        a list). Failed entries hold "index", "ok" (False), "text", "error"
        (the exception message) and "embedding" (None).
    """
    records = []

    for position, item in enumerate(texts):
        try:
            usable = isinstance(item, str) and bool(item.strip())
            if not usable:
                raise ValueError("Empty or non-string text.")

            # encode() output is converted to a plain list for the response
            vector = model.encode(item).tolist()
        except Exception as exc:
            # Record the failure in the logs, then report it in-band
            logger.warning(f"Failed to embed record {position}: {exc} | text={repr(item)}")
            record = {
                "index": position,
                "ok": False,
                "text": item,
                "error": str(exc),
                "embedding": None,
            }
        else:
            record = {
                "index": position,   # position in the batch
                "ok": True,          # success flag
                "text": item,
                "embedding": vector,
            }

        records.append(record)

    return records


# 4. Serialize output
def output_fn(prediction, accept):
    """
    Serialize the prediction as a JSON string for the response.

    The ``accept`` value does not change the result: the JSON-flavoured
    types ("application/json", "application/jsonlines", "text/json") and
    every other value all produce the same JSON body.

    Args:
        prediction: JSON-serializable object to encode.
        accept: Requested response content type (not used to pick a format).

    Returns:
        A ``(body, content_type)`` tuple where ``body`` is
        ``json.dumps(prediction)`` and ``content_type`` is always
        "application/json".
    """
    response_content_type = "application/json"
    # Unrecognized accept types fall back to JSON as well, so a single
    # serialization path covers every case.
    return json.dumps(prediction), response_content_type