ibelem commited on
Commit
96a90aa
·
verified ·
1 Parent(s): 02564b6

Upload 2 files

Browse files
Files changed (2) hide show
  1. README.md +13 -0
  2. optimize_sam_decoder.py +98 -0
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: text-to-image
3
+ inference: false
4
+ ---
5
+ # Model summary
6
+ This Segment Anything Model has been optimized to work with WebNN. This model is licensed under the [Apache-2.0](https://github.com/facebookresearch/segment-anything?tab=Apache-2.0-1-ov-file#readme) License. For terms of use, please visit the [Code of Conduct](https://github.com/facebookresearch/segment-anything/blob/main/CODE_OF_CONDUCT.md). If you comply with the license and terms of use, you have the rights described therin. By using this Model, you accept the terms.
7
+
8
+ Segment-Anything-WebNN is meant to be used with the corresponding sample [here](https://microsoft.github.io/webnn-developer-preview/).
9
+
10
+ # Model changes
11
+ Segment-Anything-Model-WebNN is an ONNX version of the Segment Anything Model, and is optimized for WebNN by using static input shapes and eliminates operators that are not in use.
12
+
13
+ Please find the original Segment Anything Model [here](https://github.com/facebookresearch/segment-anything).
optimize_sam_decoder.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnx
2
+ from onnx import helper, TensorProto
3
+ import numpy as np
4
+
5
+ # Optimize the SAM decoder model to accept dynamic original image size input
6
+ # to enable constant folding with freeDimensionOverride option in onnxruntime.
7
+ def optimize_sa_model(model_path: str, out_path: str, is_fp16: bool = False):
8
+ model = onnx.load(model_path)
9
+
10
+ graph = model.graph
11
+
12
+ old_input_name="orig_im_size"
13
+ new_input_name="orig_im_size_shape"
14
+
15
+ # 1) Find old input
16
+ old_inputs = {vi.name: vi for vi in graph.input}
17
+ assert old_input_name in old_inputs, f"Input {old_input_name} not found"
18
+ old_vi = old_inputs[old_input_name]
19
+
20
+ # 2) Remove old input and add new input
21
+ graph.input.remove(old_vi)
22
+ # new input with shape: [height, width]
23
+ new_input_vi = helper.make_tensor_value_info(new_input_name, TensorProto.FLOAT, ["height", "width"])
24
+ graph.input.extend([new_input_vi])
25
+
26
+ # Check if new_input_name exists in the graph inputs
27
+ if new_input_name not in [input.name for input in graph.input]:
28
+ raise ValueError(f"Input '{new_input_name}' does not exist in the graph inputs.")
29
+
30
+ # 3) Insert Shape node: Shape(X) -> shape_X(INT64 1D tensor [H, W])
31
+ shape_output_name = old_input_name # keep the same name as old input
32
+ shape_node = helper.make_node(
33
+ "Shape",
34
+ inputs=[new_input_name],
35
+ outputs=[shape_output_name],
36
+ name="shape_of_orig_im_size_shape"
37
+ )
38
+ # Insert the Shape node
39
+ graph.node.insert(0, shape_node)
40
+
41
+ # 4) The origin input dtype is not INT64, need to add Cast node
42
+ # But the original model has already a Cast node after the input, ignore it
43
+
44
+ if is_fp16:
45
+ # For fp16 model, since CPU kernel doesn't support constant folding
46
+ # for fp16 data type, we need to convert some fp16 constants and input/output info to fp32
47
+ fp16_constants = ["/Constant_85", "/Constant_86"]
48
+
49
+ # Convert fp16 constants in fp16_constants to fp32
50
+ for node in graph.node:
51
+ if node.op_type == "Constant" and node.name in fp16_constants:
52
+ print(node.name)
53
+ # Locate the "value" attribute of the Constant node
54
+ for attr in node.attribute:
55
+ if attr.name == "value":
56
+ # Extract the tensor value
57
+ tensor = onnx.numpy_helper.to_array(attr.t)
58
+
59
+ # Convert the tensor to the target data type
60
+ new_tensor = tensor.astype(np.float32)
61
+
62
+ # Create a new ONNX tensor with the updated data type
63
+ attr.t.CopyFrom(onnx.numpy_helper.from_array(new_tensor))
64
+ break
65
+ else:
66
+ raise ValueError(f"Constant node '{node.name}' does not have a 'value' attribute.")
67
+
68
+ fp16_nodes = ["/ReduceMax", "/Reciprocal", "/Mul_19", "/Mul_20", "/Add_11", "/Floor"]
69
+ # Change fp16 nodes in fp16_nodes to fp32
70
+ for node in graph.node:
71
+ if node.name in fp16_nodes:
72
+ print(f"Processing node: {node.name}")
73
+ for input_name in node.input:
74
+ for value_info in graph.value_info:
75
+ if value_info.name == input_name:
76
+ value_info.type.tensor_type.elem_type = TensorProto.FLOAT
77
+ print(f" - Change input: {input_name} to fp32")
78
+
79
+ for output_name in node.output:
80
+ for value_info in graph.value_info:
81
+ if value_info.name == output_name:
82
+ value_info.type.tensor_type.elem_type = TensorProto.FLOAT
83
+ print(f" - Change output: {output_name} to fp32")
84
+
85
+ # Change /Cast_9 to fp32
86
+ for node in graph.node:
87
+ if node.name == "/Cast_9":
88
+ node.attribute[0].i = TensorProto.FLOAT
89
+ print(f"Changed /Cast_9 to fp32")
90
+ break
91
+ onnx.checker.check_model(model)
92
+ onnx.save(model, out_path)
93
+ print(f"Saved to {out_path}")
94
+
95
+ # the original int8 decoder model: https://huggingface.co/schmuell/sam-b-fp16/blob/main/sam_vit_b-decoder-int8.onnx
96
+ # optimize_sa_model("sam_vit_b-decoder-int8.onnx", "sam_vit_b-decoder-int8-orig-img-size-dynamic.onnx", False)
97
+ # the original fp32 decoder model: https://huggingface.co/schmuell/sam-b-fp16/blob/main/sam_vit_b_01ec64.decoder.onnx
98
+ optimize_sa_model("sam_vit_b_01ec64.decoder-fp16.onnx", "sam_vit_b_01ec64.decoder-orig-img-size-dynamic-fp16.onnx", True)