| import torch |
| import onnx |
| import onnxruntime as ort |
| import os |
| from transformers import AutoModelForImageSegmentation |
|
|
| |
# Path where the exported ONNX model is (or will be) stored.
onnx_model_path = "background_remover.onnx"


# Export the PyTorch segmentation model to ONNX once; skip the (slow)
# conversion on subsequent runs if the file already exists.
if not os.path.exists(onnx_model_path):
    print("Converting PyTorch model to ONNX...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): trust_remote_code executes model code fetched from the
    # hub repo; acceptable here only because the source ("briaai/RMBG-2.0")
    # is an explicitly pinned, known repository.
    model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-2.0", trust_remote_code=True
    ).to(device)
    model.eval()

    # Tracing input. 512x512 is only the example shape used while tracing;
    # the spatial dims are declared dynamic in dynamic_axes below.
    dummy_input = torch.randn(1, 3, 512, 512).to(device)

    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            onnx_model_path,
            opset_version=13,
            input_names=["input"],
            output_names=["output"],
            do_constant_folding=True,
            # Allow variable batch size and image resolution at inference time.
            dynamic_axes={
                "input": {0: "batch_size", 2: "height", 3: "width"},
                "output": {0: "batch_size", 2: "height", 3: "width"},
            },
        )

    print(f"ONNX model successfully saved to {onnx_model_path}")
else:
    print("ONNX model already exists. Skipping conversion.")
|
|
| |
# Validate the exported model by actually creating an inference session.
try:
    # Prefer CUDA but always include the CPU provider so onnxruntime can
    # fall back instead of failing when torch sees a GPU but the ORT CUDA
    # provider is not installed/usable.
    if torch.cuda.is_available():
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    session = ort.InferenceSession(onnx_model_path, providers=providers)
    # Report the provider the session actually bound, not just the one requested.
    print(f"ONNX model successfully loaded using {session.get_providers()[0]}!")
except Exception as e:
    # Top-level boundary for a CLI script: report the failure and exit normally.
    print(f"ONNX model validation failed: {e}")
|
|