|
|
|
|
|
import argparse
|
|
|
import io
|
|
|
from pathlib import Path
|
|
|
from typing import Union
|
|
|
|
|
|
import torchvision
|
|
|
import uvicorn
|
|
|
from fastapi import FastAPI, File, UploadFile
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
from PIL import Image
|
|
|
|
|
|
from ...datasets.transforms import resize_256_224, to_color, to_pil, to_ts
|
|
|
from ...models.attack import AIMAttack
|
|
|
|
|
|
|
|
|
def init_attack(ckpt: Union[str, Path, None] = None):
    """Build the AIM attack on CPU together with its input preprocessing.

    Args:
        ckpt: Optional checkpoint path forwarded to ``AIMAttack.load_ckpt``.
            Falsy values (``None`` or an empty string) skip loading, so the
            attack keeps whatever weights ``AIMAttack`` starts with.

    Returns:
        A ``(attack, attack_preproc)`` tuple:

        * ``attack`` — the ``AIMAttack`` instance after ``set_mode('eval')``.
        * ``attack_preproc`` — a ``torchvision.transforms.Compose`` applying
          the ``resize_256_224``, ``to_color`` and ``to_ts`` transforms in
          that order.
    """
    attack = AIMAttack(device='cpu')
    if ckpt:
        attack.load_ckpt(ckpt)
    attack.set_mode('eval')

    # The transform factories are concatenated with ``+`` into one sequence
    # (presumably lists of transforms) before being composed.
    attack_preproc = torchvision.transforms.Compose(resize_256_224() +
                                                    to_color() + to_ts())
    return attack, attack_preproc
|
|
|
|
|
|
|
|
|
def _parse_args():
    """Parse the command-line options of the attack server."""
    parser = argparse.ArgumentParser(description='AIM Attack API')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
    parser.add_argument('--port', type=int, default=8000, help='port')
    parser.add_argument('--ckpt',
                        type=str,
                        default=None,
                        help='path to the checkpoint')
    return parser.parse_args()


def _decode_image(data: bytes, field: str) -> Image.Image:
    """Decode uploaded bytes into a PIL image, or fail with HTTP 400.

    ``Image.open`` is lazy and only reads the header, so ``load()`` is called
    here to surface corrupt or truncated payloads as a client error instead
    of an unhandled 500 later in the preprocessing pipeline.

    Args:
        data: Raw bytes of the uploaded file (untrusted client input).
        field: Name of the form field, used in the error message.

    Raises:
        fastapi.HTTPException: 400 if Pillow cannot decode ``data``.
    """
    from fastapi import HTTPException

    try:
        image = Image.open(io.BytesIO(data))
        image.load()
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-file errors are OSError
        # subclasses; DecompressionBombError guards oversized inputs.
        raise HTTPException(
            status_code=400,
            detail=f'{field}: could not decode uploaded image') from err
    return image


def _build_app(attack, attack_preproc) -> FastAPI:
    """Create the FastAPI app exposing ``attack`` at ``/attack/aim/``.

    Args:
        attack: Callable taking ``(x_nat, x_guid)`` batched tensors, as
            returned by ``init_attack``.
        attack_preproc: Transform turning a PIL image into an input tensor.
    """
    import threading

    app = FastAPI()
    # Built once instead of on every request.
    to_pil_transform = torchvision.transforms.Compose(to_pil())
    # The endpoint below runs in FastAPI's threadpool. Serialize calls into
    # the shared attack model so concurrent requests cannot interleave
    # forward/backward passes (the old async endpoint was implicitly
    # serialized by blocking the event loop).
    attack_lock = threading.Lock()

    # POST is the proper method for multipart uploads (GET bodies are dropped
    # by many clients/proxies); GET is kept for backward compatibility.
    @app.post('/attack/aim/')
    @app.get('/attack/aim/')
    def aim_attack(x_nat: UploadFile = File(...),
                   x_guid: UploadFile = File(...)):
        """Return the adversarial version of ``x_nat`` guided by ``x_guid``.

        Declared as a plain ``def`` so FastAPI runs it in a worker thread;
        the CPU-heavy decoding, preprocessing and attack therefore no longer
        block the event loop. The response is the first image of the
        attack's output batch, encoded as PNG.
        """
        pil_x_nat = _decode_image(x_nat.file.read(), 'x_nat')
        pil_x_guid = _decode_image(x_guid.file.read(), 'x_guid')

        # Add a leading batch dimension of size 1.
        ts_x_nat = attack_preproc(pil_x_nat).unsqueeze(0)
        ts_x_guid = attack_preproc(pil_x_guid).unsqueeze(0)

        with attack_lock:
            ts_x_adv = attack(ts_x_nat, ts_x_guid)

        pil_x_adv = to_pil_transform(ts_x_adv[0])

        img_byte_array = io.BytesIO()
        pil_x_adv.save(img_byte_array, format='PNG')
        img_byte_array.seek(0)
        return StreamingResponse(img_byte_array, media_type='image/png')

    return app


def main():
    """Parse CLI options, load the attack and serve it with uvicorn."""
    args = _parse_args()
    attack, attack_preproc = init_attack(args.ckpt)
    app = _build_app(attack, attack_preproc)
    uvicorn.run(app, host=args.host, port=args.port)
|
|
|
|