import onnxruntime as ort from PIL import Image import numpy as np from torchvision import transforms import gradio as gr # --- 1. Załaduj model ONNX --- session = ort.InferenceSession("mnist_resnet18.onnx") # --- 2. Transformacje obrazu takie jak przy trenowaniu --- transform = transforms.Compose([ transforms.Grayscale(num_output_channels=3), # bo ResNet18 oczekuje 3 kanałów transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # --- 3. Funkcja predykcji --- def predict(image): """ image: PIL.Image zwraca: przewidziana cyfra (0-9) """ # transformacja + dodanie batch dimension img_t = transform(image).unsqueeze(0).numpy() # inference ONNX outputs = session.run(None, {"input": img_t}) # wybór klasy o największym prawdopodobieństwie pred = int(np.argmax(outputs[0])) return pred # --- 4. Gradio interface --- iface = gr.Interface( fn=predict, inputs=gr.Image(type="pil"), outputs="number", title="MNIST ResNet18 ONNX API", description="Prześlij obraz cyfry 0-9, model ResNet18 (ONNX) zwróci predykcję." ) # --- 5. Uruchomienie Space --- iface.launch()