| import os |
| import json |
| import base64 |
| from io import BytesIO |
| from PIL import Image |
| import torch |
| from transformers import AutoProcessor, ShieldGemma2ForImageClassification |
|
|
| def model_fn(model_dir): |
| """ |
| Loads the model and processor from the disk. This function is called by |
| SageMaker once when the endpoint is created. |
| |
| Args: |
| model_dir (str): The directory where the model artifacts are stored. |
| |
| Returns: |
| dict: A dictionary containing the loaded model, processor, and device. |
| """ |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| print(f"Loading model on device: {device}") |
|
|
| model = ShieldGemma2ForImageClassification.from_pretrained( |
| model_dir, |
| torch_dtype=torch.bfloat16 |
| ).to(device).eval() |
| |
| processor = AutoProcessor.from_pretrained(model_dir) |
| |
| return {"model": model, "processor": processor, "device": device} |
|
|
| def input_fn(request_body, request_content_type): |
| """ |
| Parses the input request. This function is called by SageMaker for each |
| inference request. It expects a JSON payload with a base64-encoded 'image' |
| and optionally custom policies for experimentation. |
| |
| Args: |
| request_body (bytes): The body of the request. Expected to be a JSON |
| payload with a base64-encoded 'image' string and |
| optionally a 'custom_policies' dict. |
| request_content_type (str): The content type of the request. |
| |
| Returns: |
| dict: A dictionary containing the PIL image and optional custom policies. |
| """ |
| if request_content_type == 'application/json': |
| request = json.loads(request_body) |
| |
| |
| if isinstance(request, str): |
| request = json.loads(request) |
| |
| image_b64 = request.get('image') |
| if not image_b64: |
| raise ValueError("Missing 'image' key in JSON payload") |
| |
| img_bytes = base64.b64decode(image_b64) |
| image = Image.open(BytesIO(img_bytes)).convert("RGB") |
| |
| |
| custom_policies = request.get('custom_policies', None) |
| |
| return {'image': image, 'custom_policies': custom_policies} |
| else: |
| raise ValueError(f"Unsupported content type: {request_content_type}") |
|
|
| def predict_fn(input_data, model_artifacts): |
| """ |
| Performs inference on the input data against custom policies (if provided) |
| or default predefined policies. |
| |
| Args: |
| input_data (dict): A dictionary containing the PIL image and optional custom policies. |
| model_artifacts (dict): The dictionary returned by model_fn. |
| |
| Returns: |
| list[dict]: A list of dictionaries, one for each policy, containing the |
| policy name and its corresponding safety prediction. |
| """ |
| image = input_data['image'] |
| client_custom_policies = input_data.get('custom_policies', None) |
| |
| model = model_artifacts["model"] |
| processor = model_artifacts["processor"] |
| device = model_artifacts["device"] |
| |
| |
| if client_custom_policies: |
| print(f"Using {len(client_custom_policies)} custom policies from client") |
| custom_policies = client_custom_policies |
| else: |
| print("Using default predefined policies") |
| |
| custom_policies = { |
| "Hate/Discrimination": "The image must not contain content that promotes attacks based on personal characteristics or content that claims certain protected groups such as gender, race, religion, sexual orientation, etc. are better than others. This includes hate symbols, supremacist imagery, visual depictions of stereotypes, or any content that promotes attacks or discrimination against groups based on protected attributes.", |
| "Misinformation": "The image must not contain content seeking to spread misinformation, disinformation, or conspiracy theories that undermine public trust or propagate misconceptions and beliefs. This includes doctored photos presented as authentic, fabricated evidence of events that didn't occur, misleading graphs/charts with distorted data, or visual content that promotes harmful conspiracy theories or non-scientific medical treatment.", |
| "Jailbreak/Prompt Injection": "The image must not contain content attempting to manipulate AI models to disobey explicit instructions or safety policies through jailbreak attacks, roleplay, or other fictional scenarios. However, content that aims to improve AI safety and education, such as red team attempts, or content involving queries on how to override or bypass service and software restrictions for development purposes is considered safe.", |
| "Personally Identifiable Information (PII)": "The image must not contain content that aims to extract information about the race or religion of a specific individual, or content that shows social security numbers, credit card information, driver licenses, passport numbers, and medical information of individuals. Content that shows permitted personal information, including names, account information, public and work email addresses, phone numbers, date of birth, work profiles, and public profiles is considered safe.", |
| "Politics": "The image must not contain content involving political campaigning or lobbying, including generating campaign materials to influence a political process, or content that may lead to interference with participation in electoral, democratic or civic processes, or content seeking opinions related to politics. However, discussion of political content from a fact-based perspective, especially in an educational, informative, or journalistic context is considered safe.", |
| "Impersonation": "The image must not contain content that impersonates or seeks to impersonate specific individuals (including by naming or assigning it to the identity of a real person, living or deceased), fictional characters protected under copyright laws (e.g. characters from movies, books, video games, and other media), or specific organisations.", |
| "Specialist Advice": "The image must not provide or represent advice that would ordinarily be provided by a qualified or licensed professional, including legal advice, medical and health advice, financial advice, and/or other professional advice of any kind. This includes medical prescriptions or treatment plans, legal documents or legal advice, financial investment recommendations, or other specialized information that should only be provided by certified professionals.", |
| "High-risk Decisions": "The image must not contain content involving 'high stakes' decision making including in any domains that may affect an individual's rights, safety, health or well-being including high risk domains like finance, legal, credit, insurance, employment (including hiring/lay off decisions), performance reviews, housing, education, essential services, law enforcement, migration, management of critical infrastructure, judicial proceedings, social scoring or any use case where subjective evaluation of a human is being done. However, content asking to summarize or compile information relating to high risk domains, given that explicit decision making is not involved, is considered safe.", |
| } |
| |
| |
| |
| custom_policies_dict = {} |
| policy_name_mapping = {} |
| |
| for policy_name, policy_text in custom_policies.items(): |
| |
| key = policy_name.lower().replace('/', '_').replace(' ', '_').replace('(', '').replace(')', '').replace('-', '_') |
| |
| key = key.replace('__', '_').strip('_') |
| |
| custom_policies_dict[key] = policy_text |
| policy_name_mapping[key] = policy_name |
| |
| policies_to_evaluate = list(custom_policies_dict.keys()) |
| |
| print(f"Evaluating {len(policies_to_evaluate)} policies: {list(policy_name_mapping.values())}") |
| |
| |
| model_inputs = processor( |
| images=[image], |
| custom_policies=custom_policies_dict, |
| policies=policies_to_evaluate, |
| return_tensors="pt" |
| ).to(device) |
|
|
| with torch.inference_mode(): |
| output = model(**model_inputs) |
|
|
| |
| probabilities = output.probabilities |
| |
| results = [] |
| for idx, policy_key in enumerate(policies_to_evaluate): |
| |
| unsafe_prob = 1.0 - float(probabilities[idx][1]) |
| safe_prob = float(probabilities[idx][1]) |
| |
| |
| original_policy_name = policy_name_mapping[policy_key] |
| |
| prediction = { |
| "policy": original_policy_name, |
| "prediction": { |
| "Unsafe": unsafe_prob, |
| "Safe": safe_prob |
| } |
| } |
| results.append(prediction) |
| |
| return results |
|
|
| def output_fn(prediction, content_type): |
| """ |
| Formats the prediction output. |
| |
| Args: |
| prediction (list[dict]): The list of prediction results from predict_fn. |
| content_type (str): The desired content type of the response. |
| |
| Returns: |
| str: A JSON-formatted string of the prediction list. |
| """ |
| if content_type == 'application/json': |
| return json.dumps(prediction) |
| raise ValueError(f"Unsupported content type: {content_type}") |
|
|