Commit
·
d9b2258
1
Parent(s):
08d7494
Create onnx_kv_inject.py
Browse files- onnx_kv_inject.py +17 -0
onnx_kv_inject.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import click
|
| 2 |
+
import os
|
| 3 |
+
import onnx
|
| 4 |
+
from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector
|
| 5 |
+
from sparseml.onnx.utils import ONNXGraph
|
| 6 |
+
@click.command()
@click.option(
    '--input-file',
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help='Path to the input ONNX model file',
)
@click.option(
    '--output-file',
    required=True,
    type=click.Path(dir_okay=False),
    help='Output path for the modified model',
)
def modify_model(input_file, output_file):
    """Apply ``KeyValueCacheInjector`` to an ONNX model and save the result.

    Steps, in order:

    1. Load the model at ``input_file`` with ``load_external_data=False``.
    2. Run ``KeyValueCacheInjector(model_path=<model directory>).apply`` on it.
    3. Wrap the result in ``ONNXGraph`` and call
       ``delete_orphaned_node_branches()``.
    4. Save the model to ``output_file`` and print a confirmation line.

    Args:
        input_file: Path to an existing ONNX model file.
        output_file: Path the modified model is written to.

    NOTE(review): the injector is handed the directory containing the model;
    presumably it reads model configuration from there -- confirm against the
    sparseml documentation.
    NOTE(review): external tensor data is not loaded, so a model that keeps
    weights in side files presumably still points at them after saving --
    verify before writing ``output_file`` into a different directory.
    """
    # os.path.dirname() of a bare filename ("model.onnx") is "", which would
    # hand the injector an empty model path. Resolve to an absolute path first
    # so the directory is always well defined.
    model_dir = os.path.dirname(os.path.abspath(input_file))

    model = onnx.load(input_file, load_external_data=False)
    model = KeyValueCacheInjector(model_path=model_dir).apply(model)

    graph = ONNXGraph(model)
    graph.delete_orphaned_node_branches()

    onnx.save(model, output_file)
    print(f"Modified model saved to: {output_file}")
|
| 16 |
+
# Script entry point: click parses the command-line options and invokes
# modify_model with them.
if __name__ == "__main__":
    modify_model()
|