# app.py
# Author: Abhi2025 — "Update app.py" (commit bdee10b, verified)
import torch
import onnx
import onnxruntime as ort
import os
from transformers import AutoModelForImageSegmentation
# Path where the exported ONNX model is cached on disk.
onnx_model_path = "background_remover.onnx"

# Convert only once: skip the (slow) export if the ONNX file already exists.
if not os.path.exists(onnx_model_path):
    print("🚀 Converting PyTorch model to ONNX...")

    # Run the export on GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the pre-trained RMBG-2.0 background-removal model.
    # trust_remote_code=True is required: this repo ships custom model code.
    model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-2.0", trust_remote_code=True
    ).to(device)
    model.eval()  # inference mode: freezes dropout / batch-norm behavior

    # Dummy input used only for tracing; the 512x512 size is arbitrary
    # because height/width are declared dynamic below.
    dummy_input = torch.randn(1, 3, 512, 512, device=device)

    # Export without gradient tracking to save memory during tracing.
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            onnx_model_path,
            opset_version=13,          # stable, widely supported opset
            input_names=["input"],
            output_names=["output"],
            do_constant_folding=True,  # fold constant subgraphs at export time
            # Allow arbitrary batch size and image resolution at inference.
            dynamic_axes={
                "input": {0: "batch_size", 2: "height", 3: "width"},
                "output": {0: "batch_size", 2: "height", 3: "width"},
            },
        )
    print(f"✅ ONNX model successfully saved to {onnx_model_path}")
else:
    print("🔄 ONNX model already exists. Skipping conversion.")
# 🔹 Validate the ONNX model by loading it into an inference session.
try:
    # Prefer CUDA when available, but always list CPU as a fallback so
    # session creation still succeeds if the CUDA execution provider
    # cannot be loaded (e.g. CUDA/ORT version mismatch).
    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if torch.cuda.is_available()
        else ["CPUExecutionProvider"]
    )
    session = ort.InferenceSession(onnx_model_path, providers=providers)
    # Report the provider ORT actually selected — it may differ from the
    # first one requested if a provider failed to initialize.
    print(f"✅ ONNX model successfully loaded using {session.get_providers()[0]}!")
except Exception as e:
    # Top-level boundary: report the failure instead of crashing the app.
    print(f"❌ ONNX model validation failed: {e}")