File size: 2,074 Bytes
998bb30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
#!/usr/bin/env python3
import argparse
import io
from pathlib import Path
from typing import Union
import torchvision
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from ...datasets.transforms import resize_256_224, to_color, to_pil, to_ts
from ...models.attack import AIMAttack
def init_attack(ckpt: Union[str, Path, None] = None):
    """Create the AIM attack model and its input preprocessing pipeline.

    Args:
        ckpt: Optional checkpoint path. When truthy it is passed to
            ``AIMAttack.load_ckpt``; when ``None`` (or an empty string) no
            checkpoint is loaded.

    Returns:
        A ``(attack, attack_preproc)`` tuple: the ``AIMAttack`` instance
        constructed with ``device='cpu'`` and switched to ``'eval'`` mode,
        and a ``torchvision.transforms.Compose`` that applies the
        ``resize_256_224``, ``to_color`` and ``to_ts`` transforms in that
        order.
    """
    attack = AIMAttack(device='cpu')
    if ckpt:
        attack.load_ckpt(ckpt)
    attack.set_mode('eval')
    # The transform helpers each return a list, so they are concatenated
    # into a single pipeline.
    attack_preproc = torchvision.transforms.Compose(resize_256_224() +
                                                    to_color() + to_ts())
    return attack, attack_preproc
def main():
    """Parse CLI options, load the AIM attack, and serve it over HTTP.

    Exposes ``/attack/aim/`` (GET and POST). The endpoint takes two
    multipart image uploads, ``x_nat`` and ``x_guid``, preprocesses both,
    runs ``attack(x_nat, x_guid)``, and returns the first image of the
    result encoded as PNG. Undecodable uploads are rejected with HTTP 400.
    Blocks in ``uvicorn.run`` until the server stops.
    """
    import threading

    from fastapi import HTTPException

    parser = argparse.ArgumentParser(description='AIM Attack API')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
    parser.add_argument('--port', type=int, default=8000, help='port')
    parser.add_argument('--ckpt',
                        type=str,
                        default=None,
                        help='path to the checkpoint')
    args = parser.parse_args()

    attack, attack_preproc = init_attack(args.ckpt)
    # The output transform does not depend on the request; build it once.
    attack_postproc = torchvision.transforms.Compose(to_pil())
    # The handler below is a plain ``def``, so FastAPI runs it in a
    # threadpool and requests may overlap. Serialize model calls so the
    # attack still runs one request at a time, as it did when the work
    # executed directly on the event loop.
    attack_lock = threading.Lock()

    app = FastAPI()

    def _load_image(upload: UploadFile, field: str) -> Image.Image:
        """Decode an uploaded file into a PIL image, or fail with HTTP 400."""
        data = upload.file.read()
        try:
            image = Image.open(io.BytesIO(data))
            # Image.open is lazy; decode now so corrupt or truncated data
            # fails here rather than deep inside the preprocessing pipeline.
            image.load()
        except (OSError, Image.DecompressionBombError) as err:
            # UnidentifiedImageError is a subclass of OSError.
            raise HTTPException(
                status_code=400,
                detail=f'{field}: could not decode image') from err
        return image

    # POST is the conventional method for multipart uploads (a GET body has
    # no defined semantics and is dropped by many clients/proxies); GET is
    # kept so existing callers continue to work.
    @app.api_route('/attack/aim/', methods=['GET', 'POST'])
    def aim_attack(x_nat: UploadFile = File(...),
                   x_guid: UploadFile = File(...)):
        # Sync handler: decoding and inference are CPU-bound and would block
        # the event loop inside an ``async def``.
        ts_x_nat = attack_preproc(_load_image(x_nat, 'x_nat')).unsqueeze(0)
        ts_x_guid = attack_preproc(_load_image(x_guid,
                                               'x_guid')).unsqueeze(0)
        with attack_lock:
            ts_x_adv = attack(ts_x_nat, ts_x_guid)
        pil_x_adv = attack_postproc(ts_x_adv[0])
        img_byte_array = io.BytesIO()
        pil_x_adv.save(img_byte_array, format='PNG')
        img_byte_array.seek(0)
        return StreamingResponse(img_byte_array, media_type='image/png')

    uvicorn.run(app, host=args.host, port=args.port)
|