|
|
import tempfile

import requests
import torch
from torchvision.io import decode_image, read_image
|
|
|
|
|
def preprocess(image):
    """Turn a channels-first image tensor into a normalized single-image batch.

    The image is min-max stretched to the full 0-255 range, quantized to
    8 bits, rescaled to floats in [0, 1], and given a leading batch
    dimension so it can be passed straight to the model.

    Args:
        image: Image tensor, channels first (e.g. the ``(C, H, W)`` uint8
            tensor returned by ``torchvision.io.read_image``). Only the first
            three channels are used, so an alpha channel is dropped.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, min(C, 3), H, W)`` with
        values in ``[0, 1]``. A constant image (max == min) yields all zeros.
    """
    # Drop extra channels (e.g. alpha) *before* computing statistics, so a
    # discarded channel cannot skew the min/max used to stretch the RGB data.
    image = image[:3, ...]

    low = image.min()
    span = image.max() - low
    if span == 0:
        # A flat image has no contrast to stretch; dividing by zero would
        # produce NaNs, whose uint8 conversion is undefined.
        quantized = torch.zeros_like(image, dtype=torch.uint8)
    else:
        # Stretch to 0-255 and quantize (truncating) to 8 bits. The 8-bit
        # step is kept to match the original pipeline, presumably the one
        # the model was trained with — TODO confirm.
        quantized = (255.0 * (image - low) / span).to(torch.uint8)

    # Rescale 8-bit values to [0, 1] floats (the rescale torchvision's
    # ToDtype(torch.float, scale=True) applies to uint8 input), then add
    # the batch dimension.
    return (quantized.to(torch.float32) / 255.0).unsqueeze(0)
|
|
|
|
|
def filter_predictions(pred, score_threshold=0.5):
    """Keep only the detections scored strictly above ``score_threshold``.

    Args:
        pred: Mapping of per-detection fields that includes a ``"scores"``
            entry. Every value is indexed with the same boolean mask, so each
            is presumably a tensor whose first dimension enumerates the
            detections (boxes, labels, masks, ...) — TODO confirm.
        score_threshold: Exclusive lower bound; a detection whose score
            equals the threshold is dropped.

    Returns:
        A new dict with the same keys in the same order, each value holding
        only the entries selected by the score mask.
    """
    above_threshold = pred["scores"] > score_threshold
    kept = {field: column[above_threshold] for field, column in pred.items()}
    return kept
|
|
|
|
|
|
|
|
# Load the complete pickled model (an nn.Module, not a state_dict) onto the CPU.
# weights_only=False is required to unpickle a whole module on torch >= 2.6,
# where the default became True. Unpickling can execute arbitrary code, so only
# load a model.pth you trust.
model = torch.load("model.pth", map_location=torch.device("cpu"), weights_only=False)
# The pickle keeps whatever train/eval mode the model was saved in; switch to
# eval so layers such as dropout/batch-norm (and presumably the detection
# heads) run in inference mode.
model.eval()

image_url = "https://github.com/cyber2a/Cyber2A-RTS-ToyModel/blob/main/data/images/valtest_nitze_008.jpg?raw=true"

# Bound the wait and fail fast on HTTP errors instead of trying to decode an
# error page as a JPEG.
response = requests.get(image_url, timeout=30)
response.raise_for_status()

# Decode the downloaded bytes in memory. The previous temp-file round-trip
# could read the file before its buffered write was flushed, and an open
# NamedTemporaryFile cannot be reopened by name on Windows.
image = decode_image(torch.frombuffer(bytearray(response.content), dtype=torch.uint8))

scaled_tensor = preprocess(image)

with torch.no_grad():
    output = model(scaled_tensor)

# output is indexed as a per-image sequence; [0] is the single image in the batch.
print(filter_predictions(output[0]))
|
|
|