File size: 1,338 Bytes
d1eb2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import tempfile

import requests
import torch
from torchvision.io import decode_image
from torchvision.io import read_image

def preprocess(image):
    """Min-max stretch an image tensor and batch it for the model.

    The image is linearly rescaled so its smallest value maps to 0 and its
    largest to 255 (truncated to uint8). Only the first three entries of
    dim 0 are kept. The result is converted to float in [0, 1] and given a
    leading batch dimension.

    Args:
        image: Image tensor, assumed channels-first (C, H, W) as returned by
            ``torchvision.io.read_image``/``decode_image``. The ``[:3]`` slice
            drops any channels beyond the third (e.g. alpha).

    Returns:
        float32 tensor of shape (1, min(C, 3), H, W) with values in [0, 1].
        A constant image (max == min) yields all zeros rather than NaN-derived
        garbage.
    """
    lo = image.min()
    span = image.max() - lo
    if span == 0:
        # A constant image would compute 0 / 0 = NaN, and casting NaN to uint8
        # is undefined; every pixel equals the minimum, so map it to 0.
        stretched = torch.zeros_like(image, dtype=torch.uint8)
    else:
        # NOTE: min/max are taken over all channels, including any that the
        # [:3] slice below discards.
        stretched = (255.0 * (image - lo) / span).to(torch.uint8)

    rgb = stretched[:3, ...]
    # uint8 -> float32 in [0, 1]; same arithmetic torchvision's
    # v2.ToDtype(torch.float, scale=True) applies to uint8 input.
    scaled = rgb.to(torch.float32).mul_(1.0 / 255)

    # Add the batch dimension expected by the model.
    return scaled.unsqueeze(0)

def filter_predictions(pred, score_threshold=0.5):
    """Keep only detections whose score is strictly above ``score_threshold``.

    Every value in ``pred`` is indexed with the same boolean mask built from
    ``pred["scores"]``, so all entries stay aligned per detection.

    Args:
        pred: Mapping of name -> tensor. It must contain a ``"scores"`` entry,
            and every value must be indexable by that entry's boolean mask.
        score_threshold: Scores must be greater than this value to be kept.

    Returns:
        A new dict with the same keys, in the same order, holding only the
        selected entries.
    """
    mask = pred["scores"] > score_threshold
    filtered = {}
    for name, values in pred.items():
        filtered[name] = values[mask]
    return filtered


# NOTE(review): torch.load unpickles arbitrary Python objects -- only load a
# checkpoint from a trusted source. The file holds a whole pickled model (the
# loaded object is called directly below), so weights_only=False is required:
# PyTorch >= 2.6 defaults weights_only to True and would refuse to load it.
model = torch.load("model.pth", map_location=torch.device('cpu'), weights_only=False)
# Switch dropout/batch-norm to inference behaviour. Presumably this is a
# torchvision-style detector (its predictions carry "scores"), which also
# requires targets when left in training mode.
model.eval()

image_url = "https://github.com/cyber2a/Cyber2A-RTS-ToyModel/blob/main/data/images/valtest_nitze_008.jpg?raw=true"
# Bound the wait on the network and fail fast on HTTP errors instead of
# trying to decode an error page as an image.
response = requests.get(image_url, timeout=30)
response.raise_for_status()

# Decode the downloaded bytes in memory. Round-tripping through a
# NamedTemporaryFile and re-reading it by name needs an explicit flush
# (otherwise the on-disk file may be truncated), and it cannot be reopened
# on Windows while still open.
image_bytes = torch.frombuffer(bytearray(response.content), dtype=torch.uint8)
image = decode_image(image_bytes)

scaled_tensor = preprocess(image)
with torch.no_grad():
    output = model(scaled_tensor)

# output[0]: predictions for the single image in the batch.
print(filter_predictions(output[0]))