Miczu212 commited on
Commit
ebb8759
verified
1 Parent(s): 8509f0a

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +45 -0
  2. mnist_resnet18.onnx +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime as ort
2
+ from PIL import Image
3
+ import numpy as np
4
+ from torchvision import transforms
5
+ import gradio as gr
6
+
7
+ # --- 1. Za艂aduj model ONNX ---
8
+ session = ort.InferenceSession("mnist_resnet18.onnx")
9
+
10
+ # --- 2. Transformacje obrazu takie jak przy trenowaniu ---
11
+ transform = transforms.Compose([
12
+ transforms.Grayscale(num_output_channels=3), # bo ResNet18 oczekuje 3 kana艂贸w
13
+ transforms.Resize((224, 224)),
14
+ transforms.ToTensor(),
15
+ transforms.Normalize([0.485, 0.456, 0.406],
16
+ [0.229, 0.224, 0.225])
17
+ ])
18
+
19
+ # --- 3. Funkcja predykcji ---
20
+ def predict(image):
21
+ """
22
+ image: PIL.Image
23
+ zwraca: przewidziana cyfra (0-9)
24
+ """
25
+ # transformacja + dodanie batch dimension
26
+ img_t = transform(image).unsqueeze(0).numpy()
27
+
28
+ # inference ONNX
29
+ outputs = session.run(None, {"input": img_t})
30
+
31
+ # wyb贸r klasy o najwi臋kszym prawdopodobie艅stwie
32
+ pred = int(np.argmax(outputs[0]))
33
+ return pred
34
+
35
+ # --- 4. Gradio interface ---
36
+ iface = gr.Interface(
37
+ fn=predict,
38
+ inputs=gr.Image(type="pil"),
39
+ outputs="number",
40
+ title="MNIST ResNet18 ONNX API",
41
+ description="Prze艣lij obraz cyfry 0-9, model ResNet18 (ONNX) zwr贸ci predykcj臋."
42
+ )
43
+
44
+ # --- 5. Uruchomienie Space ---
45
+ iface.launch()
mnist_resnet18.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8ab4dbd54edbb7c28b7394ded11a98a3846dcdb68362c3fd9b2e5e56a1d792a
3
+ size 44717065
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ onnx
2
+ onnxruntime
3
+ Pillow
4
+ gradio