| | import torch |
| | from PIL import Image |
| | from .preprocess import preprocess_image |
| | from .utils import load_model |
| |
|
| |
|
def predict_with_model(model, inputs):
    """Run ``model`` on ``inputs`` and return the index of the top-scoring class.

    Side effect: ``model`` is switched to eval mode via ``model.eval()``.
    The forward pass runs with autograd disabled.

    Args:
        model: Callable whose output exposes a ``logits`` attribute
            (presumably a Hugging Face-style classifier — confirm against
            ``load_model``).
        inputs: Mapping unpacked as keyword arguments into ``model``.

    Returns:
        The argmax over the last axis of ``logits``, converted with
        ``.item()``. The reduced result must hold exactly one element
        (e.g. a batch of one); otherwise ``.item()`` raises.
    """
    model.eval()
    with torch.no_grad():
        result = model(**inputs)
        best_index = torch.argmax(result.logits, dim=-1)
    return best_index.item()
| |
|
| |
|
def predict(image_path):
    """Load an image, preprocess it, run the model, and return the prediction.

    The image is decoded as RGB and passed through ``preprocess_image``.
    The resulting tensors are moved to ``model.device`` and scored by
    ``predict_with_model``.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (a path or a
            binary file object).

    Returns:
        The value returned by ``predict_with_model``: the argmax class index.
    """
    # Use Image.open as a context manager so the file handle is closed
    # deterministically. Without it, the handle leaks until garbage
    # collection when decoding fails or the format is multi-frame.
    # convert() loads the pixel data and returns a new image, so ``image``
    # stays usable after the block exits.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = preprocess_image(image)

    # NOTE(review): the model is reloaded on every call; if load_model() is
    # expensive, consider loading once at the caller.
    model = load_model()

    # Put every input on the same device as the model. ``model.device``
    # presumably comes from a Hugging Face-style model — confirm.
    device = model.device
    inputs = {name: value.to(device) for name, value in inputs.items()}

    return predict_with_model(model, inputs)
| |
|