File size: 2,074 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python3
import argparse
import io
from pathlib import Path
from typing import Union

import torchvision
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image

from ...datasets.transforms import resize_256_224, to_color, to_pil, to_ts
from ...models.attack import AIMAttack


def init_attack(ckpt: Union[str, Path, None] = None):
    """Build the AIM attack on CPU together with its input preprocessing.

    Args:
        ckpt: Optional checkpoint path forwarded to ``AIMAttack.load_ckpt``.
            Any falsy value (``None`` or an empty string) skips checkpoint
            loading, leaving the attack as ``AIMAttack`` constructs it.

    Returns:
        A ``(attack, attack_preproc)`` tuple: the attack after
        ``set_mode('eval')``, and a ``torchvision.transforms.Compose`` that
        maps an input image to the tensor fed to the attack.
    """
    attack = AIMAttack(device='cpu')
    # Truthiness check (not ``is not None``): an empty path also means
    # "no checkpoint".
    if ckpt:
        attack.load_ckpt(ckpt)
    attack.set_mode('eval')
    # The helpers return sequences of transforms, so they are concatenated
    # before being composed. Presumably: resize/crop to 224, colour handling,
    # conversion to tensor (inferred from the names; confirm in transforms).
    attack_preproc = torchvision.transforms.Compose(resize_256_224() +
                                                    to_color() + to_ts())
    return attack, attack_preproc


def main():
    """Parse CLI options and serve the AIM attack over HTTP with uvicorn.

    Exposes ``/attack/aim/``, which accepts two uploaded images (``x_nat`` and
    ``x_guid``), runs the attack on them and streams the resulting image back
    as PNG. Uploads that cannot be decoded as images get an HTTP 400.
    Blocks until the uvicorn server exits.
    """
    # Local import keeps this entry point self-contained.
    from fastapi import HTTPException

    parser = argparse.ArgumentParser(description='AIM Attack API')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
    parser.add_argument('--port', type=int, default=8000, help='port')
    parser.add_argument('--ckpt',
                        type=str,
                        default=None,
                        help='path to the checkpoint')
    args = parser.parse_args()

    attack, attack_preproc = init_attack(args.ckpt)
    # Built once at startup instead of on every request.
    attack_postproc = torchvision.transforms.Compose(to_pil())
    app = FastAPI()

    def load_image(raw: bytes, field: str) -> Image.Image:
        """Decode uploaded bytes into a fully loaded PIL image.

        Raises:
            HTTPException: status 400 if ``raw`` is not a decodable image, so a
                bad upload is reported as a client error rather than a 500.
        """
        try:
            image = Image.open(io.BytesIO(raw))
            # Force decoding now so truncated files fail here, not later.
            image.load()
        except OSError as err:  # PIL.UnidentifiedImageError is an OSError
            raise HTTPException(status_code=400,
                                detail=f'{field} is not a valid image') from err
        return image

    # Uploaded files travel in the request body, which has undefined
    # semantics on GET and is dropped by many clients/proxies. Serve POST as
    # well; keep GET so existing callers keep working.
    @app.post('/attack/aim/')
    @app.get('/attack/aim/')
    async def aim_attack(x_nat: UploadFile = File(...),
                         x_guid: UploadFile = File(...)):
        pil_x_nat = load_image(await x_nat.read(), 'x_nat')
        pil_x_guid = load_image(await x_guid.read(), 'x_guid')
        # unsqueeze(0) prepends a batch dimension of size 1.
        ts_x_nat = attack_preproc(pil_x_nat).unsqueeze(0)
        ts_x_guid = attack_preproc(pil_x_guid).unsqueeze(0)
        ts_x_adv = attack(ts_x_nat, ts_x_guid)
        # Take the single batch element and convert it back to a PIL image.
        pil_x_adv = attack_postproc(ts_x_adv[0])

        img_byte_array = io.BytesIO()
        pil_x_adv.save(img_byte_array, format='PNG')
        # Rewind so the response streams the PNG from its first byte.
        img_byte_array.seek(0)

        return StreamingResponse(img_byte_array, media_type='image/png')

    uvicorn.run(app, host=args.host, port=args.port)