code / sagemaker /inference.py
24122168-collab
Add application file
6ba100e
"""
inference.py β€” SageMaker Entry Point | Email Gatekeeper | Phase 1
========================================================================
Think of this file like a circuit board with 4 connectors.
SageMaker plugs into each one in order, every time a request arrives:
[1] model_fn β†’ Power-on. Runs ONCE when the server starts.
[2] input_fn β†’ Input pin. Reads the raw HTTP request bytes.
[3] predict_fn β†’ Logic gate. Runs your classifier, scores the result.
[4] output_fn β†’ Output pin. Sends the JSON response back.
Your classifier lives in classifier.py (same folder).
No GPU, no heavy ML libraries needed β€” pure Python logic.
"""
import json
import os
import uuid
import logging
from datetime import datetime, timezone
# classifier.py must be in the same folder as this file
from classifier import classify, decode, extract_features
# Module-level logger. The original author notes that SageMaker forwards
# logger.info() output to CloudWatch Logs.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# ── Reward weights (must match environment.py exactly) ───────────────────────
# These are the scores the RL agent was trained against.
# They are reported in logs / responses only — never used for routing.
_REWARDS = dict(
    EXACT=1.0,           # all 3 dimensions correct
    PARTIAL_1=0.2,       # urgency correct, 1 other dimension wrong
    PARTIAL_2=0.1,       # urgency correct, both other dimensions wrong
    SECURITY_MISS=-2.0,  # security email but urgency was NOT flagged as 2
    WRONG=0.0,           # urgency wrong on a non-security email
)

# ── SLA table: urgency code → response time target ───────────────────────────
_SLA = {
    urgency_code: {"priority": priority, "respond_within_minutes": minutes}
    for urgency_code, priority, minutes in (
        (0, "P3", 1440),  # General  — 24 h
        (1, "P2", 240),   # Billing  — 4 h
        (2, "P1", 15),    # Security — 15 min
    )
}
# ─────────────────────────────────────────────────────────────────────────────
# HELPER: Partial-match scorer
# ─────────────────────────────────────────────────────────────────────────────
def _score_match(predicted: tuple, ground_truth: dict) -> dict:
    """
    Grade a predicted (urgency, routing, resolution) triple against the
    known-correct answer supplied in the request's "ground_truth" field.

    Args:
        predicted:    3-item sequence of codes, ordered
                      (urgency, routing, resolution).
        ground_truth: mapping with "urgency", "routing" and "resolution"
                      entries; each value is passed through int() before
                      it is compared.

    Returns:
        dict with keys:
            status       — EXACT / PARTIAL_1 / PARTIAL_2 / SECURITY_MISS / WRONG
            reward       — the _REWARDS entry for that status
            correct_dims — {dimension: bool} for all three dimensions
            wrong_fields — names of the dimensions that did not match

    Raises:
        KeyError:   ground_truth lacks one of the three dimensions.
        ValueError: predicted does not unpack into exactly three items, or a
                    ground-truth value is rejected by int().
        TypeError:  a ground-truth value is of a type int() cannot convert.
    """
    p_urgency, p_routing, p_resolution = predicted
    guessed = {
        "urgency": p_urgency,
        "routing": p_routing,
        "resolution": p_resolution,
    }
    # Same evaluation order as the dimensions are listed: urgency first.
    expected = {dim: int(ground_truth[dim]) for dim in guessed}

    correct = {dim: guessed[dim] == expected[dim] for dim in guessed}
    wrong = [dim for dim in correct if not correct[dim]]

    # Decision tree — first matching rule wins.
    if expected["urgency"] == 2 and p_urgency != 2:
        # A security email that was not flagged as security: worst outcome,
        # checked before anything else.
        status = "SECURITY_MISS"
    elif not wrong:
        status = "EXACT"
    elif not correct["urgency"]:
        # Urgency itself is wrong → no credit, whatever else matched.
        status = "WRONG"
    else:
        # Urgency is right, so `wrong` holds one or two of routing/resolution.
        status = "PARTIAL_1" if len(wrong) == 1 else "PARTIAL_2"

    reward = _REWARDS[status]
    logger.info(
        "MATCH_EVAL | status=%s reward=%.1f wrong_fields=%s",
        status, reward, wrong
    )
    return {
        "status": status,
        "reward": reward,
        "correct_dims": correct,
        "wrong_fields": wrong,
    }
# ─────────────────────────────────────────────────────────────────────────────
# [1] model_fn β€” Power-on. Runs ONCE at container start.
# ─────────────────────────────────────────────────────────────────────────────
def model_fn(model_dir: str) -> dict:
    """
    Build the "model" object that predict_fn receives on every request.

    The classifier is rule-based, so there are no weights to load. The model
    is a plain config dict, optionally customised by a config.json placed in
    model_dir.

    Args:
        model_dir: directory that may contain a config.json with optional
                   "version" and "sla" entries.

    Returns:
        dict with keys:
            version       — config "version", default "1.0.0"
            sla           — {int urgency code: SLA entry}; the built-in _SLA
                            table with any config "sla" entries layered on top
            endpoint_name — value of the SAGEMAKER_ENDPOINT_NAME environment
                            variable, or "local" when it is not set

    Raises:
        json.JSONDecodeError: config.json exists but is not valid JSON.
        ValueError: an "sla" key in config.json is not an integer code.
    """
    logger.info("model_fn | model_dir=%s", model_dir)

    # Optional runtime overrides shipped alongside the model artefacts.
    config_path = os.path.join(model_dir, "config.json")
    config = {}
    if os.path.exists(config_path):
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        logger.info("Config loaded: %s", config)

    # JSON object keys are always strings ("0", "1", "2"), but predict_fn looks
    # the SLA up by the integer urgency code. Normalise the keys here so an
    # override does not turn every request into a KeyError. Start from the
    # defaults so a partial override (e.g. only code 2) keeps the other rows.
    sla = dict(_SLA)
    for code, target in (config.get("sla") or {}).items():
        sla[int(code)] = target

    model = {
        "version": config.get("version", "1.0.0"),
        "sla": sla,
        "endpoint_name": os.environ.get("SAGEMAKER_ENDPOINT_NAME", "local"),
    }
    logger.info("Model ready | version=%s endpoint=%s",
                model["version"], model["endpoint_name"])
    return model
# ─────────────────────────────────────────────────────────────────────────────
# [2] input_fn β€” Input pin. Deserialise the raw HTTP request.
# ─────────────────────────────────────────────────────────────────────────────
def input_fn(request_body: str | bytes, content_type: str) -> dict:
    """
    Convert the raw HTTP POST body into the payload dict used by predict_fn.

    Accepted request formats:

    Format A — JSON with raw email text (most common):
        {
            "subject": "Your account was hacked",
            "body": "We detected unauthorized access..."
        }

    Format B — JSON with pre-extracted features (skips feature extraction):
        {
            "keywords": ["hacked", "password"],
            "sentiment": "negative",
            "context": "security"
        }

    Format C — add ground_truth to either format above for accuracy scoring:
        {
            "subject": "...",
            "body": "...",
            "ground_truth": {"urgency": 2, "routing": 1, "resolution": 2}
        }

    A "text/plain" body is also accepted; the whole text becomes "body" with
    an empty "subject".

    Args:
        request_body: request payload as text or UTF-8 encoded bytes.
        content_type: request Content-Type; parameters after ";" are ignored.

    Returns:
        The parsed payload dict.

    Raises:
        ValueError: unsupported or missing content type, malformed JSON
            (json.JSONDecodeError is a ValueError), a JSON body that is not
            an object, or a payload with nothing to classify.
    """
    logger.info("input_fn | content_type=%s", content_type)

    # A missing content type is reported as "unsupported" instead of crashing
    # on None.lower().
    ct = (content_type or "").lower().split(";")[0].strip()
    if ct not in ("application/json", "text/plain"):
        raise ValueError(
            f"Unsupported content type: '{content_type}'. "
            "Send 'application/json' or 'text/plain'."
        )

    if isinstance(request_body, (bytes, bytearray)):
        request_body = request_body.decode("utf-8")

    if ct == "application/json":
        payload = json.loads(request_body)
        # Valid JSON is not necessarily an object: '[1, 2]' or '"text"' would
        # otherwise fail below with an AttributeError on .get().
        if not isinstance(payload, dict):
            raise ValueError(
                "JSON request body must be an object, "
                f"got {type(payload).__name__}."
            )
    else:
        # text/plain: treat the entire body as the email body.
        payload = {"subject": "", "body": request_body}

    # Must have at least something to classify.
    classifiable = ("subject", "body", "keywords", "context")
    if not any(payload.get(field) for field in classifiable):
        raise ValueError(
            "Request must include 'subject', 'body', 'keywords', or 'context'."
        )
    return payload
# ─────────────────────────────────────────────────────────────────────────────
# [3] predict_fn β€” Logic gate. Run the classifier.
# ─────────────────────────────────────────────────────────────────────────────
def predict_fn(data: dict, model: dict) -> dict:
    """
    The main classification step. Runs on every request.

    Step 1: Extract features from raw text (or use pre-supplied features)
    Step 2: Classify → 3 codes (urgency, routing, resolution)
    Step 3: Decode codes → human-readable labels
    Step 4: Score against ground_truth (only if ground_truth is in request)
    Step 5: Return everything as a dict (output_fn formats it)

    Args:
        data:  payload dict produced by input_fn.
        model: dict produced by model_fn; "sla" and "endpoint_name" are read.

    Returns:
        dict with keys urgency_code, routing_code, resolution_code, labels,
        features, match, sla and endpoint.

    Raises:
        KeyError: the SLA table has no entry for the predicted urgency code.
    """
    logger.info("predict_fn | keys=%s", list(data.keys()))

    # ── Step 1: Feature extraction ────────────────────────────────────────────
    if data.get("context"):
        # Fast path: caller already extracted features.
        features = {
            "keywords": data.get("keywords", []),
            "sentiment": data.get("sentiment", "neutral"),
            "context": data["context"],
        }
    else:
        # Text path: derive features from the raw subject + body.
        features = extract_features(
            subject=data.get("subject", ""),
            body=data.get("body", ""),
        )

    # ── Step 2: Classify → 3 codes ────────────────────────────────────────────
    urgency, routing, resolution = classify(features)

    # ── Step 3: Decode to human-readable labels ───────────────────────────────
    labels = decode(urgency, routing, resolution)
    logger.info(
        "CLASSIFIED | category=%s dept=%s action=%s | context=%s keywords=%s",
        labels["urgency"], labels["routing"], labels["resolution"],
        features["context"], features["keywords"],
    )

    # ── Step 4: Score against ground_truth (optional) ─────────────────────────
    ground_truth = data.get("ground_truth")
    if ground_truth:
        match = _score_match((urgency, routing, resolution), ground_truth)
    else:
        # No ground_truth supplied — nothing to verify the prediction against.
        match = {"status": "UNVERIFIED", "reward": None,
                 "correct_dims": {}, "wrong_fields": []}

    # An SLA table that came from config.json has string keys ("0", "1", "2")
    # because JSON object keys are always strings. Try the code as-is first,
    # then its string form, so such a table does not fail every request.
    sla_table = model["sla"]
    sla = sla_table[urgency] if urgency in sla_table else sla_table[str(urgency)]

    # ── Step 5: Return raw prediction dict ────────────────────────────────────
    return {
        "urgency_code": urgency,
        "routing_code": routing,
        "resolution_code": resolution,
        "labels": labels,
        "features": features,
        "match": match,
        "sla": sla,
        "endpoint": model["endpoint_name"],
    }
# ─────────────────────────────────────────────────────────────────────────────
# [4] output_fn β€” Output pin. Format and send the response.
# ─────────────────────────────────────────────────────────────────────────────
def output_fn(prediction: dict, accept: str) -> tuple[str, str]:
    """
    Convert the prediction dict into the final HTTP response body.

    Default response format: application/json
    Optional CSV format:     text/csv (a single row, no header, no newline)

    JSON response shape:
        {
            "request_id": "uuid",
            "timestamp": "2024-01-15T10:30:00Z",
            "triage": {
                "category": "Security Breach",   ← urgency label
                "department": "Tech Support",    ← routing label
                "action": "Escalate"             ← resolution label
            },
            "codes": {
                "urgency": 2, "routing": 1, "resolution": 2
            },
            "features": {
                "keywords": [...], "sentiment": "...", "context": "..."
            },
            "match_result": {
                "status": "EXACT",    ← or PARTIAL_1 / PARTIAL_2 /
                                        SECURITY_MISS / WRONG / UNVERIFIED
                "reward": 1.0,        ← RL reward score (null when UNVERIFIED)
                "wrong_fields": []    ← which dimensions were wrong
            },
            "sla": {
                "priority": "P1",
                "respond_within_minutes": 15
            }
        }

    CSV column order:
        request_id, category, department, action, urgency, routing,
        resolution, match status, reward, SLA priority

    Args:
        prediction: dict produced by predict_fn.
        accept:     requested response type; parameters after ";" are ignored
                    and a falsy value means JSON. Anything other than
                    "text/csv" yields JSON.

    Returns:
        (body, content_type) tuple.
    """
    accept_type = (accept or "application/json").lower().split(";")[0].strip()

    labels = prediction["labels"]
    features = prediction["features"]
    match = prediction["match"]

    response = {
        "request_id": str(uuid.uuid4()),
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        "triage": {
            "category": labels["urgency"],
            "department": labels["routing"],
            "action": labels["resolution"],
        },
        "codes": {
            "urgency": prediction["urgency_code"],
            "routing": prediction["routing_code"],
            "resolution": prediction["resolution_code"],
        },
        "features": {
            "keywords": features["keywords"],
            "sentiment": features["sentiment"],
            "context": features["context"],
        },
        "match_result": {
            "status": match["status"],
            "reward": match["reward"],
            "wrong_fields": match["wrong_fields"],
        },
        "sla": prediction["sla"],
    }

    # ── CSV output ────────────────────────────────────────────────────────────
    if accept_type == "text/csv":
        reward = response["match_result"]["reward"]
        row = ",".join([
            response["request_id"],
            response["triage"]["category"],
            response["triage"]["department"],
            response["triage"]["action"],
            str(response["codes"]["urgency"]),
            str(response["codes"]["routing"]),
            str(response["codes"]["resolution"]),
            str(response["match_result"]["status"]),
            # Only a missing reward (None) becomes an empty column. A plain
            # truthiness test would also blank out a real reward of 0.0.
            "" if reward is None else str(reward),
            response["sla"]["priority"],
        ])
        return row, "text/csv"

    return json.dumps(response, ensure_ascii=False), "application/json"