# RTS / inference.py
# Author: Ben Galewsky
# Commit: Add example (d1eb2ce)
import torch
import requests
import tempfile
from torchvision.io import read_image
def preprocess(image):
    """Normalize an image tensor and batch it for model inference.

    The input is min-max rescaled to the full uint8 range, truncated to
    its first three channels (drops an alpha channel if present), converted
    to float32 in [0, 1], and given a leading batch dimension.

    Args:
        image: a (C, H, W) tensor as returned by torchvision's read_image.
            Assumed channels-first — TODO confirm with callers.

    Returns:
        A (1, 3, H, W) float32 tensor with values in [0, 1].
    """
    from torchvision.transforms import v2 as T

    transform = T.Compose([
        T.ToDtype(torch.float, scale=True),  # uint8 -> float32 in [0, 1]
        T.ToPureTensor(),
    ])

    # Min-max rescale to [0, 255]. Clamp the denominator so a constant
    # image (max == min) does not divide by zero and produce NaNs.
    lo = image.min()
    rng = (image.max() - lo).clamp(min=1)
    image = (255.0 * (image - lo) / rng).to(torch.uint8)

    # Keep only the first three channels (RGB), dropping any alpha channel.
    image = image[:3, ...]

    transformed_image = transform(image)
    return torch.unsqueeze(transformed_image, 0)
def filter_predictions(pred, score_threshold=0.5):
    """Drop low-confidence detections from a prediction dict.

    Args:
        pred: mapping of field name (e.g. "boxes", "labels", "scores") to a
            tensor whose first dimension indexes detections.
        score_threshold: detections with score <= this value are removed.

    Returns:
        A new dict with every tensor filtered down to the detections whose
        score strictly exceeds the threshold.
    """
    mask = pred["scores"] > score_threshold
    filtered = {}
    for field, values in pred.items():
        filtered[field] = values[mask]
    return filtered
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# model.pth from a trusted source.
model = torch.load("model.pth", map_location=torch.device('cpu'))
model.eval()  # bug fix: switch dropout/batch-norm layers to inference mode

# Sample image from the Cyber2A RTS toy-model repository.
image_url = "https://github.com/cyber2a/Cyber2A-RTS-ToyModel/blob/main/data/images/valtest_nitze_008.jpg?raw=true"
response = requests.get(image_url, timeout=30)
response.raise_for_status()  # fail loudly on a bad download instead of decoding garbage

with tempfile.NamedTemporaryFile(delete=True, suffix='.jpg') as tmp_file:
    # Write the content to the temporary file and flush so the bytes are on
    # disk before read_image opens the path through a separate descriptor.
    tmp_file.write(response.content)
    tmp_file.flush()
    image = read_image(tmp_file.name)

scaled_tensor = preprocess(image)

with torch.no_grad():
    output = model(scaled_tensor)
    # Detection models return one prediction dict per batch element;
    # the batch size here is 1, so report predictions for output[0].
    print(filter_predictions(output[0]))