# image-moderation-v1 / code / inference.py
# Hugging Face file-viewer header (page residue, not program text):
#   uploaded by TomLEE2026 - "Upload 14 files" - commit 5cda760 (verified) - 9.97 kB
import os
import json
import base64
from io import BytesIO
from PIL import Image
import torch
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
def model_fn(model_dir):
    """
    Load the classifier and its processor from the model directory.

    Per the handler contract described in this file, SageMaker calls this
    once when the endpoint is created.

    Args:
        model_dir (str): The directory where the model artifacts are stored.

    Returns:
        dict: The loaded objects under the keys "model", "processor" and
            "device" (the device is the string "cuda" or "cpu").
    """
    # Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Loading model on device: {device}")

    # Weights are loaded in bfloat16, moved to the chosen device and put
    # into evaluation mode before being handed to the request handlers.
    classifier = ShieldGemma2ForImageClassification.from_pretrained(
        model_dir,
        torch_dtype=torch.bfloat16,
    )
    classifier = classifier.to(device)
    classifier = classifier.eval()

    image_processor = AutoProcessor.from_pretrained(model_dir)

    return {"model": classifier, "processor": image_processor, "device": device}
def input_fn(request_body, request_content_type):
    """
    Parse an inference request into a PIL image plus optional policies.

    Per the handler contract described in this file, SageMaker calls this
    for each inference request. The body must be JSON (optionally
    double-encoded as a JSON string) with a base64-encoded 'image' and,
    optionally, a 'custom_policies' object for experimentation.

    Args:
        request_body (bytes | str): The JSON request body.
        request_content_type (str): The content type of the request. Only
            the media type is compared, so "application/json; charset=utf-8"
            is accepted as well as plain "application/json".

    Returns:
        dict: {'image': <PIL image converted to RGB>,
               'custom_policies': <dict of policy name -> text, or None>}.

    Raises:
        ValueError: If the content type is not JSON, the payload is not a
            JSON object, 'image' is missing or not a string, or
            'custom_policies' is present but not an object. Malformed JSON
            and malformed base64 also surface as ValueError subclasses.
    """
    # Strip media-type parameters (e.g. "; charset=utf-8") before comparing;
    # media types are case-insensitive.
    media_type = (request_content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")

    request = json.loads(request_body)
    # Handle cases where the payload might be double-encoded as a JSON string
    if isinstance(request, str):
        request = json.loads(request)
    # A JSON list/number/bool has no .get(); reject it with a clear message
    # instead of failing with AttributeError below.
    if not isinstance(request, dict):
        raise ValueError("JSON payload must be an object with an 'image' key")

    image_b64 = request.get("image")
    if not image_b64:
        raise ValueError("Missing 'image' key in JSON payload")
    if not isinstance(image_b64, str):
        raise ValueError("'image' must be a base64-encoded string")

    # Optional custom policies for experimentation. Validate before the
    # (comparatively expensive) image decode so bad requests fail fast.
    # Falsy values are passed through unchanged; predict_fn treats them as
    # "use the defaults".
    custom_policies = request.get("custom_policies")
    if custom_policies and not isinstance(custom_policies, dict):
        raise ValueError(
            "'custom_policies' must be an object mapping policy names to policy text"
        )

    img_bytes = base64.b64decode(image_b64)
    image = Image.open(BytesIO(img_bytes)).convert("RGB")

    return {"image": image, "custom_policies": custom_policies}
# Default custom policies (same as original notebook). Built once at import
# time instead of on every request; predict_fn only reads it.
_DEFAULT_POLICIES = {
    "Hate/Discrimination": "The image must not contain content that promotes attacks based on personal characteristics or content that claims certain protected groups such as gender, race, religion, sexual orientation, etc. are better than others. This includes hate symbols, supremacist imagery, visual depictions of stereotypes, or any content that promotes attacks or discrimination against groups based on protected attributes.",
    "Misinformation": "The image must not contain content seeking to spread misinformation, disinformation, or conspiracy theories that undermine public trust or propagate misconceptions and beliefs. This includes doctored photos presented as authentic, fabricated evidence of events that didn't occur, misleading graphs/charts with distorted data, or visual content that promotes harmful conspiracy theories or non-scientific medical treatment.",
    "Jailbreak/Prompt Injection": "The image must not contain content attempting to manipulate AI models to disobey explicit instructions or safety policies through jailbreak attacks, roleplay, or other fictional scenarios. However, content that aims to improve AI safety and education, such as red team attempts, or content involving queries on how to override or bypass service and software restrictions for development purposes is considered safe.",
    "Personally Identifiable Information (PII)": "The image must not contain content that aims to extract information about the race or religion of a specific individual, or content that shows social security numbers, credit card information, driver licenses, passport numbers, and medical information of individuals. Content that shows permitted personal information, including names, account information, public and work email addresses, phone numbers, date of birth, work profiles, and public profiles is considered safe.",
    "Politics": "The image must not contain content involving political campaigning or lobbying, including generating campaign materials to influence a political process, or content that may lead to interference with participation in electoral, democratic or civic processes, or content seeking opinions related to politics. However, discussion of political content from a fact-based perspective, especially in an educational, informative, or journalistic context is considered safe.",
    "Impersonation": "The image must not contain content that impersonates or seeks to impersonate specific individuals (including by naming or assigning it to the identity of a real person, living or deceased), fictional characters protected under copyright laws (e.g. characters from movies, books, video games, and other media), or specific organisations.",
    "Specialist Advice": "The image must not provide or represent advice that would ordinarily be provided by a qualified or licensed professional, including legal advice, medical and health advice, financial advice, and/or other professional advice of any kind. This includes medical prescriptions or treatment plans, legal documents or legal advice, financial investment recommendations, or other specialized information that should only be provided by certified professionals.",
    "High-risk Decisions": "The image must not contain content involving 'high stakes' decision making including in any domains that may affect an individual's rights, safety, health or well-being including high risk domains like finance, legal, credit, insurance, employment (including hiring/lay off decisions), performance reviews, housing, education, essential services, law enforcement, migration, management of critical infrastructure, judicial proceedings, social scoring or any use case where subjective evaluation of a human is being done. However, content asking to summarize or compile information relating to high risk domains, given that explicit decision making is not involved, is considered safe.",
}

# '/', ' ' and '-' become underscores; parentheses are dropped.
_POLICY_KEY_TRANSLATION = str.maketrans(
    {"/": "_", " ": "_", "-": "_", "(": None, ")": None}
)


def _policy_key(policy_name):
    """Turn a policy display name into a lowercase, underscore-separated key."""
    translated = policy_name.lower().translate(_POLICY_KEY_TRANSLATION)
    # Dropping empty segments collapses every run of underscores (not just
    # pairs) and trims leading/trailing ones in a single step.
    return "_".join(segment for segment in translated.split("_") if segment)


def predict_fn(input_data, model_artifacts):
    """
    Score one image against custom policies (if provided) or the defaults.

    Args:
        input_data (dict): Must contain 'image'; may contain
            'custom_policies', a dict mapping policy display names to
            policy text. A missing or falsy value selects the defaults.
        model_artifacts (dict): The dictionary returned by model_fn, with
            the keys "model", "processor" and "device".

    Returns:
        list[dict]: One entry per policy, in the order the policies were
            given, shaped as
            {"policy": <original name>,
             "prediction": {"Unsafe": <float>, "Safe": <float>}}.

    Raises:
        ValueError: If 'custom_policies' is truthy but not a dict.
    """
    image = input_data['image']
    client_custom_policies = input_data.get('custom_policies', None)
    model = model_artifacts["model"]
    processor = model_artifacts["processor"]
    device = model_artifacts["device"]

    # Use client-provided policies if available, otherwise use defaults
    if client_custom_policies:
        if not isinstance(client_custom_policies, dict):
            raise ValueError(
                "'custom_policies' must be a dict mapping policy names to policy text"
            )
        print(f"Using {len(client_custom_policies)} custom policies from client")
        custom_policies = client_custom_policies
    else:
        print("Using default predefined policies")
        custom_policies = _DEFAULT_POLICIES

    # Build the processor-facing dict. Keys are internal identifiers only;
    # policy_name_mapping maps each key back to the name used in the response.
    custom_policies_dict = {}
    policy_name_mapping = {}
    for policy_name, policy_text in custom_policies.items():
        base_key = _policy_key(policy_name) or "policy"
        # Distinct names can normalise to the same key (e.g. "Hate Speech"
        # and "hate-speech"). Suffix later ones so no policy is silently
        # overwritten and dropped from the evaluation.
        key = base_key
        suffix = 2
        while key in custom_policies_dict:
            key = f"{base_key}_{suffix}"
            suffix += 1
        custom_policies_dict[key] = policy_text
        policy_name_mapping[key] = policy_name

    policies_to_evaluate = list(custom_policies_dict)
    print(f"Evaluating {len(policies_to_evaluate)} policies: {list(policy_name_mapping.values())}")

    model_inputs = processor(
        images=[image],
        custom_policies=custom_policies_dict,
        policies=policies_to_evaluate,  # Which policies to evaluate
        return_tensors="pt",
    ).to(device)

    with torch.inference_mode():
        output = model(**model_inputs)

    # output.probabilities is indexed as [policy_row][column]; column 1 is
    # reported as the "Safe" probability. Pull that column to the host in
    # one transfer rather than converting one element at a time.
    safe_probs = output.probabilities[:, 1].float().cpu().tolist()

    results = []
    for idx, policy_key in enumerate(policies_to_evaluate):
        safe_prob = safe_probs[idx]
        results.append(
            {
                # Use the original policy name for reporting
                "policy": policy_name_mapping[policy_key],
                "prediction": {
                    # Unsafe probability is the complement of the safe one.
                    "Unsafe": 1.0 - safe_prob,
                    "Safe": safe_prob,
                },
            }
        )
    return results
def output_fn(prediction, content_type):
    """
    Serialize the prediction list as JSON.

    Args:
        prediction (list[dict]): The list of prediction results from predict_fn.
        content_type (str): The desired content type of the response. Only
            the media type is compared, so "application/json; charset=utf-8"
            is accepted as well as plain "application/json".

    Returns:
        str: A JSON-formatted string of the prediction list.

    Raises:
        ValueError: If the requested content type is not JSON.
    """
    # Strip media-type parameters before comparing; media types are
    # case-insensitive.
    media_type = (content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")
    return json.dumps(prediction)