#!/usr/bin/env python3
"""HTTP API that serves the AIM attack model with FastAPI + uvicorn.

This module uses package-relative imports, so it must be run as a module
(``python -m <dotted path of this file>``) or through an entry point that
calls :func:`main`; running the file directly as a script will fail on import.
"""
import argparse
import io
import threading
from pathlib import Path
from typing import Optional, Union

import torchvision
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from PIL import Image

from ...datasets.transforms import resize_256_224, to_color, to_pil, to_ts
from ...models.attack import AIMAttack


def init_attack(ckpt: Optional[Union[str, Path]] = None):
    """Build an ``AIMAttack`` on the CPU together with its input transform.

    Args:
        ckpt: Optional checkpoint path handed to ``AIMAttack.load_ckpt``.
            When falsy (``None`` or an empty string) no checkpoint is loaded.

    Returns:
        ``(attack, attack_preproc)``: the attack after ``set_mode('eval')``,
        and a ``torchvision.transforms.Compose`` built from
        ``resize_256_224() + to_color() + to_ts()``.
    """
    attack = AIMAttack(device='cpu')
    if ckpt:
        attack.load_ckpt(ckpt)
    attack.set_mode('eval')
    attack_preproc = torchvision.transforms.Compose(resize_256_224() + to_color() + to_ts())
    return attack, attack_preproc


async def _read_image(upload: UploadFile, field: str) -> Image.Image:
    """Read an uploaded file and fully decode it with PIL.

    Args:
        upload: The uploaded file from the request.
        field: Form-field name, used only in the error message.

    Raises:
        HTTPException: 400 when the payload cannot be decoded as an image.
    """
    data = await upload.read()
    try:
        image = Image.open(io.BytesIO(data))
        # Image.open only parses the header; force the full decode now so a
        # truncated/corrupt upload is rejected here instead of failing later
        # inside preprocessing with a generic 500.
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        raise HTTPException(status_code=400, detail=f'{field} is not a readable image') from err
    return image


def main():
    """Parse CLI arguments, build the attack, and serve it with uvicorn."""
    parser = argparse.ArgumentParser(description='AIM Attack API')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
    parser.add_argument('--port', type=int, default=8000, help='port')
    parser.add_argument('--ckpt', type=str, default=None, help='path to the checkpoint')
    args = parser.parse_args()

    attack, attack_preproc = init_attack(args.ckpt)
    # Built once instead of on every request.
    attack_postproc = torchvision.transforms.Compose(to_pil())
    # Inference now runs in worker threads; this lock keeps model calls
    # one-at-a-time, exactly as they were when they ran on the event loop.
    attack_lock = threading.Lock()

    def run_attack(pil_x_nat: Image.Image, pil_x_guid: Image.Image) -> io.BytesIO:
        """Run the attack on two PIL images; return the result as a PNG buffer."""
        with attack_lock:
            # unsqueeze(0) adds a leading dim of size 1; [0] below removes it.
            ts_x_nat = attack_preproc(pil_x_nat).unsqueeze(0)
            ts_x_guid = attack_preproc(pil_x_guid).unsqueeze(0)
            ts_x_adv = attack(ts_x_nat, ts_x_guid)
            pil_x_adv = attack_postproc(ts_x_adv[0])
        img_byte_array = io.BytesIO()
        pil_x_adv.save(img_byte_array, format='PNG')
        img_byte_array.seek(0)
        return img_byte_array

    app = FastAPI()

    # GET is kept for existing clients; POST is added because many HTTP
    # clients and proxies do not send/forward a request body on GET.
    @app.get('/attack/aim/')
    @app.post('/attack/aim/')
    async def aim_attack(x_nat: UploadFile = File(...), x_guid: UploadFile = File(...)):
        """Attack the uploaded ``x_nat``/``x_guid`` pair and return a PNG image.

        Responds 400 if either upload is not a decodable image.
        """
        pil_x_nat = await _read_image(x_nat, 'x_nat')
        pil_x_guid = await _read_image(x_guid, 'x_guid')
        # Model inference is CPU-bound; running it directly in this async
        # handler would block the event loop (and every other request).
        png = await run_in_threadpool(run_attack, pil_x_nat, pil_x_guid)
        return StreamingResponse(png, media_type='image/png')

    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == '__main__':
    main()