---
datasets:
- gOLIVES/CRACKS
---

# UNet for Seismic Image Detection

## Usage

This model is a UNet with a ResNet50 encoder, built with `segmentation_models_pytorch`. The following code downloads the weights from the Hugging Face Hub and loads them into the model:

```python
from torchinfo import summary
from safetensors.torch import load_file
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download
import torch

# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create the model
model = smp.Unet(encoder_name="resnet50", encoder_weights=None, in_channels=3, classes=1)

# Download the model weights from Hugging Face Hub
weights_path = hf_hub_download(
    repo_id="gOLIVES/UNet_Synth2CRACKS",
    filename="model.safetensors",
    cache_dir="./cache_dir"
)

# Load weights
weights = load_file(weights_path, device=device)
model.load_state_dict(weights)

# Move model to device
model = model.to(device)

# Display model summary
summary(model, (1, 3, 256, 256))
```
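
Once the weights are loaded, a forward pass might look like the sketch below. It is not taken from the original training or evaluation code. The helper name `predict_mask`, the sigmoid activation, and the 0.5 threshold are assumptions, based only on the model being created with `classes=1` and no output activation. A small stand-in layer replaces the real network so the snippet runs without downloading anything.

```python
import torch


def predict_mask(model, batch, threshold=0.5):
    """Return a binary mask of shape (N, 1, H, W) for a float batch of shape (N, 3, H, W).

    Hypothetical helper: the sigmoid and the 0.5 threshold are assumptions,
    not settings confirmed by the model authors.
    """
    model.eval()  # inference mode: fixed batch-norm statistics, no dropout
    device = next(model.parameters()).device
    with torch.no_grad():  # gradients are not needed for prediction
        logits = model(batch.to(device))
    # Assumption: the single output channel holds raw logits.
    probs = torch.sigmoid(logits)
    return (probs > threshold).to(torch.uint8)


# Stand-in network so the sketch runs on its own;
# pass the loaded `model` here instead when using the real weights.
stand_in = torch.nn.Conv2d(3, 1, kernel_size=1)
dummy_batch = torch.rand(1, 3, 256, 256)  # same input shape as the summary call
mask = predict_mask(stand_in, dummy_batch)
print(mask.shape, mask.dtype)
```

With the real model, `dummy_batch` would be replaced by a preprocessed image tensor. The normalization to apply depends on how the model was trained, which this card does not specify.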