CPU / app.py
Miczu212's picture
Upload 3 files
ebb8759 verified
import onnxruntime as ort
from PIL import Image
import numpy as np
from torchvision import transforms
import gradio as gr
# --- 1. Za艂aduj model ONNX ---
session = ort.InferenceSession("mnist_resnet18.onnx")
# --- 2. Transformacje obrazu takie jak przy trenowaniu ---
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=3), # bo ResNet18 oczekuje 3 kana艂贸w
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
# --- 3. Funkcja predykcji ---
def predict(image):
"""
image: PIL.Image
zwraca: przewidziana cyfra (0-9)
"""
# transformacja + dodanie batch dimension
img_t = transform(image).unsqueeze(0).numpy()
# inference ONNX
outputs = session.run(None, {"input": img_t})
# wyb贸r klasy o najwi臋kszym prawdopodobie艅stwie
pred = int(np.argmax(outputs[0]))
return pred
# --- 4. Gradio interface ---
iface = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs="number",
title="MNIST ResNet18 ONNX API",
description="Prze艣lij obraz cyfry 0-9, model ResNet18 (ONNX) zwr贸ci predykcj臋."
)
# --- 5. Uruchomienie Space ---
iface.launch()