"""SageMaker inference handlers for ShieldGemma2 image-policy classification.

SageMaker calls ``model_fn`` once when the endpoint is created and then
``input_fn`` -> ``predict_fn`` -> ``output_fn`` for each inference request.
"""
import os
import json
import base64
import re
from io import BytesIO

from PIL import Image
import torch
from transformers import AutoProcessor, ShieldGemma2ForImageClassification


_JSON_CONTENT_TYPE = "application/json"

# Characters that are turned into key separators when deriving a processor
# key from a human-readable policy name (runs collapse to a single "_").
_KEY_SEPARATORS = re.compile(r"[/ \-_]+")

# Default custom policies (same as original notebook). Used whenever the
# client does not send a non-empty 'custom_policies' object. Treated as
# read-only; predict_fn never mutates it.
_DEFAULT_POLICIES = {
    "Hate/Discrimination": (
        "The image must not contain content that promotes attacks based on "
        "personal characteristics or content that claims certain protected "
        "groups such as gender, race, religion, sexual orientation, etc. are "
        "better than others. This includes hate symbols, supremacist imagery, "
        "visual depictions of stereotypes, or any content that promotes attacks "
        "or discrimination against groups based on protected attributes."
    ),
    "Misinformation": (
        "The image must not contain content seeking to spread misinformation, "
        "disinformation, or conspiracy theories that undermine public trust or "
        "propagate misconceptions and beliefs. This includes doctored photos "
        "presented as authentic, fabricated evidence of events that didn't "
        "occur, misleading graphs/charts with distorted data, or visual content "
        "that promotes harmful conspiracy theories or non-scientific medical "
        "treatment."
    ),
    "Jailbreak/Prompt Injection": (
        "The image must not contain content attempting to manipulate AI models "
        "to disobey explicit instructions or safety policies through jailbreak "
        "attacks, roleplay, or other fictional scenarios. However, content that "
        "aims to improve AI safety and education, such as red team attempts, or "
        "content involving queries on how to override or bypass service and "
        "software restrictions for development purposes is considered safe."
    ),
    "Personally Identifiable Information (PII)": (
        "The image must not contain content that aims to extract information "
        "about the race or religion of a specific individual, or content that "
        "shows social security numbers, credit card information, driver "
        "licenses, passport numbers, and medical information of individuals. "
        "Content that shows permitted personal information, including names, "
        "account information, public and work email addresses, phone numbers, "
        "date of birth, work profiles, and public profiles is considered safe."
    ),
    "Politics": (
        "The image must not contain content involving political campaigning or "
        "lobbying, including generating campaign materials to influence a "
        "political process, or content that may lead to interference with "
        "participation in electoral, democratic or civic processes, or content "
        "seeking opinions related to politics. However, discussion of political "
        "content from a fact-based perspective, especially in an educational, "
        "informative, or journalistic context is considered safe."
    ),
    "Impersonation": (
        "The image must not contain content that impersonates or seeks to "
        "impersonate specific individuals (including by naming or assigning it "
        "to the identity of a real person, living or deceased), fictional "
        "characters protected under copyright laws (e.g. characters from "
        "movies, books, video games, and other media), or specific "
        "organisations."
    ),
    "Specialist Advice": (
        "The image must not provide or represent advice that would ordinarily "
        "be provided by a qualified or licensed professional, including legal "
        "advice, medical and health advice, financial advice, and/or other "
        "professional advice of any kind. This includes medical prescriptions "
        "or treatment plans, legal documents or legal advice, financial "
        "investment recommendations, or other specialized information that "
        "should only be provided by certified professionals."
    ),
    "High-risk Decisions": (
        "The image must not contain content involving 'high stakes' decision "
        "making including in any domains that may affect an individual's "
        "rights, safety, health or well-being including high risk domains like "
        "finance, legal, credit, insurance, employment (including hiring/lay "
        "off decisions), performance reviews, housing, education, essential "
        "services, law enforcement, migration, management of critical "
        "infrastructure, judicial proceedings, social scoring or any use case "
        "where subjective evaluation of a human is being done. However, content "
        "asking to summarize or compile information relating to high risk "
        "domains, given that explicit decision making is not involved, is "
        "considered safe."
    ),
}


def _media_type(content_type):
    """Return the bare, lower-cased media type of a content-type value.

    Parameters such as ``; charset=utf-8`` are dropped, so
    ``"application/json; charset=utf-8"`` yields ``"application/json"``.
    A missing/empty value yields ``""``.
    """
    if not content_type:
        return ""
    return content_type.split(";", 1)[0].strip().lower()


def model_fn(model_dir):
    """
    Loads the model and processor from the disk.

    This function is called by SageMaker once when the endpoint is created.

    Args:
        model_dir (str): The directory where the model artifacts are stored.

    Returns:
        dict: A dictionary containing the loaded model, processor, and device.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")

    model = ShieldGemma2ForImageClassification.from_pretrained(
        model_dir, torch_dtype=torch.bfloat16
    ).to(device).eval()
    processor = AutoProcessor.from_pretrained(model_dir)

    return {"model": model, "processor": processor, "device": device}


def input_fn(request_body, request_content_type):
    """
    Parses the input request.

    This function is called by SageMaker for each inference request. It
    expects a JSON payload with a base64-encoded 'image' and optionally
    custom policies for experimentation.

    Args:
        request_body (bytes): The body of the request. Expected to be a JSON
            object with a base64-encoded 'image' string (a ``data:...;base64,``
            URL prefix is also accepted) and optionally a 'custom_policies'
            object mapping policy names to policy text.
        request_content_type (str): The content type of the request. Any
            parameters (e.g. ``; charset=utf-8``) are ignored.

    Returns:
        dict: A dictionary containing the PIL image and optional custom policies.

    Raises:
        ValueError: If the content type is not JSON, the payload is not a JSON
            object, 'image' is missing or not a string, or 'custom_policies'
            is present but not a JSON object.
    """
    if _media_type(request_content_type) != _JSON_CONTENT_TYPE:
        raise ValueError(f"Unsupported content type: {request_content_type}")

    request = json.loads(request_body)
    # Handle cases where the payload might be double-encoded as a JSON string
    if isinstance(request, str):
        request = json.loads(request)
    if not isinstance(request, dict):
        raise ValueError("JSON payload must be an object")

    image_b64 = request.get("image")
    if not image_b64:
        raise ValueError("Missing 'image' key in JSON payload")
    if not isinstance(image_b64, str):
        raise ValueError("'image' must be a base64-encoded string")

    # Browser clients often send data URLs ("data:image/png;base64,<payload>");
    # only the part after the first comma is base64.
    if image_b64.startswith("data:") and "," in image_b64:
        image_b64 = image_b64.split(",", 1)[1]

    img_bytes = base64.b64decode(image_b64)
    image = Image.open(BytesIO(img_bytes)).convert("RGB")

    # Optional custom policies for experimentation. Empty/falsy values keep
    # meaning "use the defaults"; anything else must be a name -> text mapping,
    # otherwise predict_fn would fail later with an opaque AttributeError.
    custom_policies = request.get("custom_policies", None)
    if custom_policies and not isinstance(custom_policies, dict):
        raise ValueError(
            "'custom_policies' must be a JSON object mapping policy names to policy text"
        )

    return {"image": image, "custom_policies": custom_policies}


def _policy_key(policy_name):
    """Derive a processor key from a human-readable policy name.

    Lower-cases the name, drops parentheses, and turns runs of '/', ' ', '-'
    and '_' into a single '_' (leading/trailing '_' stripped). Falls back to
    "policy" when nothing is left.
    """
    key = policy_name.lower().replace("(", "").replace(")", "")
    key = _KEY_SEPARATORS.sub("_", key).strip("_")
    return key or "policy"


def _build_policy_dicts(policies):
    """Map policy names to unique processor keys.

    Returns ``(key -> policy text, key -> original name)``, both in the input
    order. Names that normalise to the same key get a numeric suffix
    ("_2", "_3", ...) so no policy is silently overwritten.
    """
    key_to_text = {}
    key_to_name = {}
    for policy_name, policy_text in policies.items():
        base_key = _policy_key(policy_name)
        key = base_key
        suffix = 2
        while key in key_to_text:
            key = f"{base_key}_{suffix}"
            suffix += 1
        key_to_text[key] = policy_text
        key_to_name[key] = policy_name
    return key_to_text, key_to_name


def predict_fn(input_data, model_artifacts):
    """
    Performs inference on the input data against custom policies (if provided)
    or default predefined policies.

    Args:
        input_data (dict): A dictionary containing the PIL image and optional
            custom policies.
        model_artifacts (dict): The dictionary returned by model_fn.

    Returns:
        list[dict]: A list of dictionaries, one for each policy, containing
            the policy name and its corresponding safety prediction, in the
            same order as the policies were given.
    """
    image = input_data["image"]
    client_custom_policies = input_data.get("custom_policies", None)
    model = model_artifacts["model"]
    processor = model_artifacts["processor"]
    device = model_artifacts["device"]

    # Use client-provided policies if available, otherwise use defaults
    if client_custom_policies:
        print(f"Using {len(client_custom_policies)} custom policies from client")
        policies = client_custom_policies
    else:
        print("Using default predefined policies")
        policies = _DEFAULT_POLICIES

    # Keys for the processor; the second dict maps keys back to the names
    # that are reported to the client.
    custom_policies_dict, policy_name_mapping = _build_policy_dicts(policies)
    policies_to_evaluate = list(custom_policies_dict)
    print(
        f"Evaluating {len(policies_to_evaluate)} policies: "
        f"{list(policy_name_mapping.values())}"
    )

    model_inputs = processor(
        images=[image],
        custom_policies=custom_policies_dict,
        policies=policies_to_evaluate,  # Which policies to evaluate
        return_tensors="pt",
    ).to(device)

    with torch.inference_mode():
        output = model(**model_inputs)

    # One row per policy (single image), shape [num_policies, 2]. Copied to
    # CPU/float32 once instead of syncing per element.
    probabilities = output.probabilities.float().cpu()

    results = []
    for idx, policy_key in enumerate(policies_to_evaluate):
        # Column 1 is the "safe" probability; unsafe is its complement.
        safe_prob = float(probabilities[idx][1])
        results.append(
            {
                "policy": policy_name_mapping[policy_key],
                "prediction": {
                    "Unsafe": 1.0 - safe_prob,
                    "Safe": safe_prob,
                },
            }
        )
    return results


def output_fn(prediction, content_type):
    """
    Formats the prediction output.

    Args:
        prediction (list[dict]): The list of prediction results from predict_fn.
        content_type (str): The desired content type of the response. Any
            parameters (e.g. ``; charset=utf-8``) are ignored.

    Returns:
        str: A JSON-formatted string of the prediction list.

    Raises:
        ValueError: If a non-JSON content type is requested.
    """
    if _media_type(content_type) == _JSON_CONTENT_TYPE:
        return json.dumps(prediction)
    raise ValueError(f"Unsupported content type: {content_type}")