| """ |
| inference.py β SageMaker Entry Point | Email Gatekeeper | Phase 1 |
| ======================================================================== |
| |
| Think of this file like a circuit board with 4 connectors. |
| SageMaker plugs into each one in order, every time a request arrives: |
| |
| [1] model_fn β Power-on. Runs ONCE when the server starts. |
| [2] input_fn β Input pin. Reads the raw HTTP request bytes. |
| [3] predict_fn β Logic gate. Runs your classifier, scores the result. |
| [4] output_fn β Output pin. Sends the JSON response back. |
| |
| Your classifier lives in classifier.py (same folder). |
| No GPU, no heavy ML libraries needed β pure Python logic. |
| """ |
|
|
| import json |
| import os |
| import uuid |
| import logging |
| from datetime import datetime, timezone |
|
|
| |
| from classifier import classify, decode, extract_features |
|
|
| |
| logger = logging.getLogger(__name__) |
| logger.setLevel(logging.INFO) |
|
|
| |
| |
| |
| _REWARDS = { |
| "EXACT": 1.0, |
| "PARTIAL_1": 0.2, |
| "PARTIAL_2": 0.1, |
| "SECURITY_MISS": -2.0, |
| "WRONG": 0.0, |
| } |
|
|
| |
| _SLA = { |
| 0: {"priority": "P3", "respond_within_minutes": 1440}, |
| 1: {"priority": "P2", "respond_within_minutes": 240}, |
| 2: {"priority": "P1", "respond_within_minutes": 15}, |
| } |
|
|
|
|
| |
| |
| |
|
|
| def _score_match(predicted: tuple, ground_truth: dict) -> dict: |
| """ |
| Compare the 3 predicted dimensions against the known correct answer. |
| |
| Only called when the request includes a "ground_truth" field. |
| Useful for: |
| - Offline evaluation / batch testing |
| - Logging accuracy metrics to CloudWatch |
| |
| Returns a dict with: |
| status β one of EXACT / PARTIAL_1 / PARTIAL_2 / SECURITY_MISS / WRONG |
| reward β float score matching your RL reward function |
| wrong_fields β list of dimension names that were predicted incorrectly |
| """ |
| p_urgency, p_routing, p_resolution = predicted |
|
|
| g_urgency = int(ground_truth["urgency"]) |
| g_routing = int(ground_truth["routing"]) |
| g_resolution = int(ground_truth["resolution"]) |
|
|
| |
| correct = { |
| "urgency": p_urgency == g_urgency, |
| "routing": p_routing == g_routing, |
| "resolution": p_resolution == g_resolution, |
| } |
| wrong = [dim for dim, ok in correct.items() if not ok] |
|
|
| |
| |
| if g_urgency == 2 and p_urgency != 2: |
| status = "SECURITY_MISS" |
|
|
| |
| elif not wrong: |
| status = "EXACT" |
|
|
| |
| elif correct["urgency"] and len(wrong) == 1: |
| status = "PARTIAL_1" |
|
|
| |
| elif correct["urgency"] and len(wrong) == 2: |
| status = "PARTIAL_2" |
|
|
| |
| else: |
| status = "WRONG" |
|
|
| logger.info( |
| "MATCH_EVAL | status=%s reward=%.1f wrong_fields=%s", |
| status, _REWARDS[status], wrong |
| ) |
|
|
| return { |
| "status": status, |
| "reward": _REWARDS[status], |
| "correct_dims": correct, |
| "wrong_fields": wrong, |
| } |
|
|
|
|
| |
| |
| |
|
|
| def model_fn(model_dir: str) -> dict: |
| """ |
| SageMaker calls this once when the container boots up. |
| model_dir is the folder where SageMaker unpacks your model.tar.gz. |
| |
| For a rule-based classifier there are no weights to load. |
| We just return a config dict that predict_fn will use. |
| """ |
| logger.info("model_fn | model_dir=%s", model_dir) |
|
|
| |
| |
| config_path = os.path.join(model_dir, "config.json") |
| config = {} |
| if os.path.exists(config_path): |
| with open(config_path) as f: |
| config = json.load(f) |
| logger.info("Config loaded: %s", config) |
|
|
| model = { |
| "version": config.get("version", "1.0.0"), |
| "sla": config.get("sla", _SLA), |
| |
| "endpoint_name": os.environ.get("SAGEMAKER_ENDPOINT_NAME", "local"), |
| } |
|
|
| logger.info("Model ready | version=%s endpoint=%s", |
| model["version"], model["endpoint_name"]) |
| return model |
|
|
|
|
| |
| |
| |
|
|
| def input_fn(request_body: str | bytes, content_type: str) -> dict: |
| """ |
| Converts the raw bytes from the HTTP POST body into a Python dict. |
| |
| Accepted request formats: |
| |
| Format A β JSON with raw email text (most common): |
| { |
| "subject": "Your account was hacked", |
| "body": "We detected unauthorized access..." |
| } |
| |
| Format B β JSON with pre-extracted features (faster, skips NLP): |
| { |
| "keywords": ["hacked", "password"], |
| "sentiment": "negative", |
| "context": "security" |
| } |
| |
| Format C β Add ground_truth to either format above for accuracy scoring: |
| { |
| "subject": "...", |
| "body": "...", |
| "ground_truth": {"urgency": 2, "routing": 1, "resolution": 2} |
| } |
| """ |
| logger.info("input_fn | content_type=%s", content_type) |
|
|
| ct = content_type.lower().split(";")[0].strip() |
|
|
| if ct == "application/json": |
| if isinstance(request_body, bytes): |
| request_body = request_body.decode("utf-8") |
| payload = json.loads(request_body) |
|
|
| elif ct == "text/plain": |
| |
| text = request_body.decode("utf-8") if isinstance(request_body, bytes) else request_body |
| payload = {"subject": "", "body": text} |
|
|
| else: |
| raise ValueError( |
| f"Unsupported content type: '{content_type}'. " |
| "Send 'application/json' or 'text/plain'." |
| ) |
|
|
| |
| if not any([payload.get("subject"), payload.get("body"), |
| payload.get("keywords"), payload.get("context")]): |
| raise ValueError( |
| "Request must include 'subject', 'body', 'keywords', or 'context'." |
| ) |
|
|
| return payload |
|
|
|
|
| |
| |
| |
|
|
| def predict_fn(data: dict, model: dict) -> dict: |
| """ |
| The main classification step. Runs on every request. |
| |
| Step 1: Extract features from raw text (or use pre-supplied features) |
| Step 2: Classify β 3 integer codes (urgency, routing, resolution) |
| Step 3: Decode codes β human labels ("Security Breach", "Escalate", ...) |
| Step 4: Score against ground_truth (only if ground_truth is in request) |
| Step 5: Return everything as a dict (output_fn will format it as JSON) |
| """ |
| logger.info("predict_fn | keys=%s", list(data.keys())) |
|
|
| |
| |
| if data.get("context"): |
| features = { |
| "keywords": data.get("keywords", []), |
| "sentiment": data.get("sentiment", "neutral"), |
| "context": data["context"], |
| } |
| |
| else: |
| features = extract_features( |
| subject=data.get("subject", ""), |
| body=data.get("body", ""), |
| ) |
|
|
| |
| urgency, routing, resolution = classify(features) |
|
|
| |
| labels = decode(urgency, routing, resolution) |
|
|
| logger.info( |
| "CLASSIFIED | category=%s dept=%s action=%s | context=%s keywords=%s", |
| labels["urgency"], labels["routing"], labels["resolution"], |
| features["context"], features["keywords"], |
| ) |
|
|
| |
| ground_truth = data.get("ground_truth") |
| if ground_truth: |
| match = _score_match((urgency, routing, resolution), ground_truth) |
| else: |
| |
| match = {"status": "UNVERIFIED", "reward": None, |
| "correct_dims": {}, "wrong_fields": []} |
|
|
| |
| return { |
| "urgency_code": urgency, |
| "routing_code": routing, |
| "resolution_code": resolution, |
| "labels": labels, |
| "features": features, |
| "match": match, |
| "sla": model["sla"][urgency], |
| "endpoint": model["endpoint_name"], |
| } |
|
|
|
|
| |
| |
| |
|
|
| def output_fn(prediction: dict, accept: str) -> tuple[str, str]: |
| """ |
| Converts the prediction dict into the final HTTP response body. |
| |
| Default response format: application/json |
| Optional CSV format: text/csv (useful for batch jobs writing to S3) |
| |
| JSON response shape: |
| { |
| "request_id": "uuid", |
| "timestamp": "2024-01-15T10:30:00Z", |
| |
| "triage": { |
| "category": "Security Breach", β urgency label |
| "department": "Tech Support", β routing label |
| "action": "Escalate" β resolution label |
| }, |
| |
| "codes": { |
| "urgency": 2, "routing": 1, "resolution": 2 |
| }, |
| |
| "match_result": { |
| "status": "EXACT", β or PARTIAL_1 / PARTIAL_2 / SECURITY_MISS / WRONG |
| "reward": 1.0, β RL reward score |
| "wrong_fields": [] β which dimensions were wrong |
| }, |
| |
| "sla": { |
| "priority": "P1", |
| "respond_within_minutes": 15 |
| } |
| } |
| """ |
| accept_type = (accept or "application/json").lower().split(";")[0].strip() |
|
|
| response = { |
| "request_id": str(uuid.uuid4()), |
| "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), |
| "triage": { |
| "category": prediction["labels"]["urgency"], |
| "department": prediction["labels"]["routing"], |
| "action": prediction["labels"]["resolution"], |
| }, |
| "codes": { |
| "urgency": prediction["urgency_code"], |
| "routing": prediction["routing_code"], |
| "resolution": prediction["resolution_code"], |
| }, |
| "features": { |
| "keywords": prediction["features"]["keywords"], |
| "sentiment": prediction["features"]["sentiment"], |
| "context": prediction["features"]["context"], |
| }, |
| "match_result": { |
| "status": prediction["match"]["status"], |
| "reward": prediction["match"]["reward"], |
| "wrong_fields": prediction["match"]["wrong_fields"], |
| }, |
| "sla": prediction["sla"], |
| } |
|
|
| |
| if accept_type == "text/csv": |
| row = ",".join([ |
| response["request_id"], |
| response["triage"]["category"], |
| response["triage"]["department"], |
| response["triage"]["action"], |
| str(response["codes"]["urgency"]), |
| str(response["codes"]["routing"]), |
| str(response["codes"]["resolution"]), |
| str(response["match_result"]["status"]), |
| str(response["match_result"]["reward"] or ""), |
| response["sla"]["priority"], |
| ]) |
| return row, "text/csv" |
|
|
| return json.dumps(response, ensure_ascii=False), "application/json" |
|
|