ONNX
English
dhvazquez committed on
Commit
f33bb6e
·
verified ·
1 Parent(s): bf96bdd

Upload 4 files

Browse files
card_segmentation.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:240b81e1de091726a689003e729f7a3caa593d476cfa300c86f2f3ed1753be60
3
+ size 16806616
card_segmentation.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:031c4491b28767baecb111c828244d29c05982c5678752e90a1095cf6c66f604
3
+ size 17371615
card_segmentation_state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bac98f25a04ad479d80ec10f66fe6365c8471fe448aff10da30c0e4722ed1b61
3
+ size 17025002
inference_example.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example inference script for card_segmentation model.
3
+ """
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ import onnxruntime as ort
10
+
11
def preprocess_image(image_path, target_size=(320, 240)):
    """
    Load and preprocess an image for model inference.

    Args:
        image_path (str): Path to input image
        target_size (tuple): Target image size (H, W)

    Returns:
        torch.Tensor: Float32 tensor of shape (1, 3, H, W),
        ImageNet mean/std normalized.

    Raises:
        FileNotFoundError: If the image cannot be read from image_path.
    """
    # Load image. cv2.imread signals failure by returning None instead
    # of raising, so check explicitly to avoid an opaque cvtColor crash.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize: cv2.resize expects (W, H) while target_size is (H, W).
    image = cv2.resize(image, (target_size[1], target_size[0]))

    # Scale to [0, 1] and apply ImageNet normalization. The mean/std
    # arrays are created as float32 on purpose: default float64 arrays
    # would promote the whole result to float64, and the float32 models
    # reject a float64 input tensor.
    image = image.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image = (image - mean) / std

    # HWC -> CHW and add the batch dimension.
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)

    return image_tensor
37
+
38
def postprocess_output(output):
    """
    Convert raw model logits into a per-pixel class-index mask.

    Args:
        output: Model output tensor of shape (1, C, H, W).

    Returns:
        np.ndarray: Segmentation mask of shape (H, W), one class
        index per pixel.
    """
    # Softmax over the channel axis, then take the most likely class
    # at every pixel; drop the batch dimension on the way out.
    pixel_classes = torch.argmax(F.softmax(output, dim=1), dim=1)
    return pixel_classes.cpu().numpy()[0]
53
+
54
def inference_pytorch(model_path, image_path):
    """
    Run inference using PyTorch model.

    Args:
        model_path (str): Path to a TorchScript model file.
        image_path (str): Path to the input image.

    Returns:
        np.ndarray: Segmentation mask for the input image.
    """
    # Load the TorchScript model on CPU and switch to eval mode.
    net = torch.jit.load(model_path, map_location='cpu')
    net.eval()

    # Preprocess, run without gradient tracking, then decode logits.
    batch = preprocess_image(image_path)
    with torch.no_grad():
        logits = net(batch)
    return postprocess_output(logits)
73
+
74
def inference_onnx(model_path, image_path):
    """
    Run inference using ONNX model.

    Args:
        model_path (str): Path to an ONNX model file.
        image_path (str): Path to the input image.

    Returns:
        np.ndarray: Segmentation mask for the input image.
    """
    session = ort.InferenceSession(model_path)

    # ONNX Runtime takes a plain numpy array keyed by the input name.
    batch = preprocess_image(image_path).numpy()
    feed = {session.get_inputs()[0].name: batch}
    raw_output = session.run(None, feed)[0]

    # Reuse the torch-based postprocessing on the numpy result.
    return postprocess_output(torch.from_numpy(raw_output))
94
+
95
def save_mask(mask, output_path):
    """Save segmentation mask as image."""
    # Scale class indices into the displayable 0-255 range and write.
    scaled = np.uint8(mask * 255)
    cv2.imwrite(output_path, scaled)
100
+
101
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run inference on card segmentation model')
    parser.add_argument('--model', type=str, required=True, help='Path to model file')
    parser.add_argument('--image', type=str, required=True, help='Path to input image')
    parser.add_argument('--output', type=str, default='output_mask.png', help='Output mask path')
    parser.add_argument('--format', type=str, choices=['pytorch', 'onnx'], default='onnx',
                        help='Model format')

    args = parser.parse_args()

    # Dispatch on the requested model format (choices= guarantees the
    # key exists), run inference, and persist the mask.
    runners = {'pytorch': inference_pytorch, 'onnx': inference_onnx}
    mask = runners[args.format](args.model, args.image)

    save_mask(mask, args.output)
    print(f"Segmentation mask saved to: {args.output}")