Spaces:
Sleeping
Sleeping
Commit ·
1958836
1
Parent(s): 6a8cad3
update: export from starry-refactor 2026-02-20 15:25
Browse files
backend/python-services/predictors/unet.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UNet model implementation.
|
| 3 |
+
Matches the architecture from deep-starry/starry/unet/ for loading .chkpt checkpoints.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DoubleConv(nn.Module):
    """Two successive (Conv3x3 -> BatchNorm -> ReLU) stages.

    The first stage maps ``in_channels`` to ``mid_channels`` (defaulting to
    ``out_channels``); the second maps ``mid_channels`` to ``out_channels``.
    Spatial size is preserved (3x3 kernels with padding 1).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Attribute name 'double_conv' is part of the checkpoint key layout; keep it.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Down(nn.Module):
    """Encoder stage: halve the spatial size with 2x2 max-pooling, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Keep the Sequential under 'maxpool_conv' so checkpoint keys resolve.
        pool_then_conv = (nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
        self.maxpool_conv = nn.Sequential(*pool_then_conv)

    def forward(self, x):
        return self.maxpool_conv(x)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, DoubleConv.

    With ``bilinear`` the upsampling is parameter-free interpolation and the
    DoubleConv squeezes through ``in_channels // 2`` hidden channels; otherwise a
    transposed convolution halves the channel count before concatenation.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad the upsampled tensor so its spatial dims match the skip connection
        # x2 (handles odd input sizes where pooling truncated a pixel).
        dh = x2.size(2) - x1.size(2)
        dw = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-pixel class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 kernel: channel mixing only, spatial size unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class UNet(nn.Module):
    """U-Net with configurable depth and initial width.

    Mirrors the architecture from deep-starry/starry/unet/ so exported .chkpt
    state dicts load directly (attribute names inc/outc/downs/ups are part of
    the checkpoint key layout).
    """

    def __init__(self, n_channels, n_classes, classify_out=True, bilinear=True, depth=4, init_width=64):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.classify_out = classify_out
        self.depth = depth
        # With bilinear upsampling the deepest encoder halves its output width so
        # the concatenated decoder input keeps the expected channel count.
        factor = 2 if bilinear else 1

        self.inc = DoubleConv(n_channels, init_width)
        self.outc = OutConv(init_width, n_classes)

        down_stages = []
        for level in range(depth):
            width_in = init_width << level
            width_out = width_in * 2
            if level == depth - 1:
                width_out //= factor
            down_stages.append(Down(width_in, width_out))

        up_stages = []
        for level in range(depth):
            width_in = init_width << (depth - level)
            width_out = width_in // 2
            if level < depth - 1:
                width_out //= factor
            up_stages.append(Up(width_in, width_out, bilinear))

        self.downs = nn.ModuleList(down_stages)
        self.ups = nn.ModuleList(up_stages)

    def forward(self, input):
        skips = []
        x = self.inc(input)
        for stage in self.downs:
            skips.append(x)
            x = stage(x)
        # Decode, consuming skip connections deepest-first.
        for stage, skip in zip(self.ups, reversed(skips)):
            x = stage(x, skip)
        if not self.classify_out:
            return x
        return self.outc(x)
|
backend/python-services/services/gauge_service.py
CHANGED
|
@@ -1,13 +1,19 @@
|
|
| 1 |
"""
|
| 2 |
Gauge prediction service.
|
| 3 |
Predicts staff gauge (height and slope) map.
|
|
|
|
| 4 |
"""
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
|
|
|
| 8 |
import PIL.Image
|
| 9 |
|
| 10 |
-
from predictors.torchscript_predictor import
|
|
|
|
| 11 |
from common.image_utils import (
|
| 12 |
array_from_image_stream, slice_feature, splice_output_tensor,
|
| 13 |
gauge_to_rgb, encode_image_base64, encode_image_bytes,
|
|
@@ -16,6 +22,80 @@ from common.image_utils import (
|
|
| 16 |
from common.transform import Composer
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class StaffGauge:
|
| 20 |
"""Staff gauge representation."""
|
| 21 |
|
|
@@ -32,14 +112,15 @@ class StaffGauge:
|
|
| 32 |
}
|
| 33 |
|
| 34 |
|
| 35 |
-
class GaugeService
|
| 36 |
-
"""Gauge prediction service
|
| 37 |
|
| 38 |
DEFAULT_TRANS = ['Mono', 'HWC2CHW']
|
| 39 |
DEFAULT_SLICING_WIDTH = 512
|
| 40 |
|
| 41 |
def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
|
| 42 |
-
|
|
|
|
| 43 |
self.composer = Composer(trans or self.DEFAULT_TRANS)
|
| 44 |
self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
|
| 45 |
|
|
@@ -70,7 +151,8 @@ class GaugeService(TorchScriptPredictor):
|
|
| 70 |
batch = torch.from_numpy(staves).to(self.device)
|
| 71 |
|
| 72 |
# Inference
|
| 73 |
-
|
|
|
|
| 74 |
|
| 75 |
# Splice output
|
| 76 |
hotmap = splice_output_tensor(output, soft=True) # (channel, height, width)
|
|
|
|
| 1 |
"""
|
| 2 |
Gauge prediction service.
|
| 3 |
Predicts staff gauge (height and slope) map.
|
| 4 |
+
Supports both TorchScript (.pt) and state_dict (.chkpt) model formats.
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
+
import yaml
|
| 13 |
import PIL.Image
|
| 14 |
|
| 15 |
+
from predictors.torchscript_predictor import resolve_model_path
|
| 16 |
+
from predictors.unet import UNet
|
| 17 |
from common.image_utils import (
|
| 18 |
array_from_image_stream, slice_feature, splice_output_tensor,
|
| 19 |
gauge_to_rgb, encode_image_base64, encode_image_bytes,
|
|
|
|
| 22 |
from common.transform import Composer
|
| 23 |
|
| 24 |
|
| 25 |
+
class _ScoreRegression(torch.nn.Module):
    """Minimal ScoreRegression shell for loading .chkpt checkpoints.

    Holds the UNet under the attribute name ``backbone`` so keys of a
    ``backbone.*`` state_dict resolve without remapping.
    """

    def __init__(self, in_channels=1, out_channels=2, unet_depth=6, unet_init_width=32):
        super().__init__()
        net = UNet(in_channels, out_channels, depth=unet_depth, init_width=unet_init_width)
        self.backbone = net

    def forward(self, input):
        return self.backbone(input)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _load_gauge_model(model_path, device):
    """Load gauge model, handling both TorchScript and state_dict formats.

    Tries ``torch.jit.load`` first; on failure falls back to building a
    ``_ScoreRegression`` (configured from a sibling ``.state.yaml`` when
    present) and loading a plain state_dict checkpoint into it.

    Args:
        model_path: Path spec understood by ``resolve_model_path``.
        device: Torch device (string or device) used as ``map_location`` and
            as the final placement of the model.

    Returns:
        The loaded model in eval mode (ScriptModule or ``_ScoreRegression``).
    """
    resolved = resolve_model_path(model_path)

    # Try TorchScript first; jit.load raising is used as format detection,
    # hence the deliberately broad except below.
    try:
        model = torch.jit.load(resolved, map_location=device)
        model.eval()
        logging.info('GaugeService: TorchScript model loaded: %s', resolved)
        return model
    except Exception as e:
        logging.info('GaugeService: not TorchScript (%s), trying state_dict...', str(e)[:60])

    # Read model config from .state.yaml next to the checkpoint; fall back to
    # the defaults below when the file is absent.
    model_dir = os.path.dirname(resolved)
    state_file = os.path.join(model_dir, '.state.yaml')
    unet_depth = 6
    unet_init_width = 32
    out_channels = 2
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            state = yaml.safe_load(f)
        model_args = state.get('model', {}).get('args', {})
        backbone = model_args.get('backbone', {})
        unet_depth = backbone.get('unet_depth', 6)
        unet_init_width = backbone.get('unet_init_width', 32)
        out_channels = model_args.get('out_channels', 2)

    model = _ScoreRegression(out_channels=out_channels, unet_depth=unet_depth, unet_init_width=unet_init_width)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)

    # Handle different checkpoint formats: either a bare state_dict or a
    # training checkpoint wrapping it under 'model'.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            state_dict = checkpoint['model']

    # Strip common prefixes from training wrapper (ScoreRegressionLoss.deducer.*)
    if isinstance(state_dict, dict):
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            new_key = key
            if new_key.startswith('deducer.'):
                new_key = new_key[len('deducer.'):]
            cleaned[new_key] = value
        # Remove non-model keys (e.g. channel_weights from Loss wrapper)
        cleaned = OrderedDict((k, v) for k, v in cleaned.items()
                              if k.startswith('backbone.'))
        state_dict = cleaned

    # strict=False tolerates missing/unexpected keys; mismatches surface only
    # in the matched-key count logged below.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)

    # Log key loading stats so silent partial loads are visible in the logs.
    model_keys = set(model.state_dict().keys())
    loaded_keys = set(state_dict.keys())
    matched = model_keys & loaded_keys
    logging.info('GaugeService: state_dict loaded: %s (%d/%d keys matched, depth=%d, width=%d)',
                 resolved, len(matched), len(model_keys), unet_depth, unet_init_width)
    return model
|
| 97 |
+
|
| 98 |
+
|
| 99 |
class StaffGauge:
|
| 100 |
"""Staff gauge representation."""
|
| 101 |
|
|
|
|
| 112 |
}
|
| 113 |
|
| 114 |
|
| 115 |
+
class GaugeService:
|
| 116 |
+
"""Gauge prediction service. Supports TorchScript and state_dict formats."""
|
| 117 |
|
| 118 |
DEFAULT_TRANS = ['Mono', 'HWC2CHW']
|
| 119 |
DEFAULT_SLICING_WIDTH = 512
|
| 120 |
|
| 121 |
def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
|
| 122 |
+
self.device = device
|
| 123 |
+
self.model = _load_gauge_model(model_path, device)
|
| 124 |
self.composer = Composer(trans or self.DEFAULT_TRANS)
|
| 125 |
self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
|
| 126 |
|
|
|
|
| 151 |
batch = torch.from_numpy(staves).to(self.device)
|
| 152 |
|
| 153 |
# Inference
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
output = self.model(batch) # (batch, channel, height, width)
|
| 156 |
|
| 157 |
# Splice output
|
| 158 |
hotmap = splice_output_tensor(output, soft=True) # (channel, height, width)
|
backend/python-services/services/mask_service.py
CHANGED
|
@@ -1,13 +1,18 @@
|
|
| 1 |
"""
|
| 2 |
Mask prediction service.
|
| 3 |
Generates staff foreground/background mask.
|
|
|
|
| 4 |
"""
|
| 5 |
|
|
|
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
|
|
|
| 8 |
import PIL.Image
|
| 9 |
|
| 10 |
-
from predictors.torchscript_predictor import
|
|
|
|
| 11 |
from common.image_utils import (
|
| 12 |
array_from_image_stream, slice_feature, splice_output_tensor,
|
| 13 |
mask_to_alpha, encode_image_base64, encode_image_bytes,
|
|
@@ -16,6 +21,65 @@ from common.image_utils import (
|
|
| 16 |
from common.transform import Composer
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class StaffMask:
|
| 20 |
"""Staff mask representation."""
|
| 21 |
|
|
@@ -32,14 +96,15 @@ class StaffMask:
|
|
| 32 |
}
|
| 33 |
|
| 34 |
|
| 35 |
-
class MaskService
|
| 36 |
-
"""Mask prediction service
|
| 37 |
|
| 38 |
DEFAULT_TRANS = ['Mono', 'HWC2CHW']
|
| 39 |
DEFAULT_SLICING_WIDTH = 512
|
| 40 |
|
| 41 |
def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
|
| 42 |
-
|
|
|
|
| 43 |
self.composer = Composer(trans or self.DEFAULT_TRANS)
|
| 44 |
self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
|
| 45 |
|
|
@@ -70,7 +135,8 @@ class MaskService(TorchScriptPredictor):
|
|
| 70 |
batch = torch.from_numpy(staves).to(self.device)
|
| 71 |
|
| 72 |
# Inference
|
| 73 |
-
|
|
|
|
| 74 |
|
| 75 |
# Splice output
|
| 76 |
hotmap = splice_output_tensor(output, soft=True) # (channel, height, width)
|
|
|
|
| 1 |
"""
|
| 2 |
Mask prediction service.
|
| 3 |
Generates staff foreground/background mask.
|
| 4 |
+
Supports both TorchScript (.pt) and state_dict (.chkpt) model formats.
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
+
import yaml
|
| 12 |
import PIL.Image
|
| 13 |
|
| 14 |
+
from predictors.torchscript_predictor import resolve_model_path
|
| 15 |
+
from predictors.unet import UNet
|
| 16 |
from common.image_utils import (
|
| 17 |
array_from_image_stream, slice_feature, splice_output_tensor,
|
| 18 |
mask_to_alpha, encode_image_base64, encode_image_bytes,
|
|
|
|
| 21 |
from common.transform import Composer
|
| 22 |
|
| 23 |
|
| 24 |
+
class _ScoreWidgetsMask(torch.nn.Module):
    """Minimal ScoreWidgetsMask shell for loading .chkpt checkpoints.

    Holds the UNet under the attribute name ``mask`` so checkpoint sub-dicts
    saved as ``{'mask': ...}`` load directly into it.
    """

    def __init__(self, in_channels=1, mask_channels=2, unet_depth=5, unet_init_width=32):
        super().__init__()
        net = UNet(in_channels, mask_channels, depth=unet_depth, init_width=unet_init_width)
        self.mask = net

    def forward(self, x):
        logits = self.mask(x)
        return torch.sigmoid(logits)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _load_mask_model(model_path, device):
    """Load mask model, handling both TorchScript and state_dict formats.

    Tries ``torch.jit.load`` first; on failure builds a ``_ScoreWidgetsMask``
    (configured from a sibling ``.state.yaml`` when present) and loads a plain
    state_dict checkpoint into it.

    Args:
        model_path: Path spec understood by ``resolve_model_path``.
        device: Torch device (string or device) used as ``map_location`` and
            as the final placement of the model.

    Returns:
        The loaded model in eval mode (ScriptModule or ``_ScoreWidgetsMask``).
    """
    resolved = resolve_model_path(model_path)

    # Try TorchScript first; jit.load raising is used as format detection,
    # hence the deliberately broad except below.
    try:
        model = torch.jit.load(resolved, map_location=device)
        model.eval()
        logging.info('MaskService: TorchScript model loaded: %s', resolved)
        return model
    except Exception as e:
        logging.info('MaskService: not TorchScript (%s), trying state_dict...', str(e)[:60])

    # Read model config from .state.yaml next to the checkpoint; fall back to
    # the defaults below when the file is absent.
    model_dir = os.path.dirname(resolved)
    state_file = os.path.join(model_dir, '.state.yaml')
    unet_depth = 5
    unet_init_width = 32
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            state = yaml.safe_load(f)
        mask_config = state.get('model', {}).get('args', {}).get('mask', {})
        unet_depth = mask_config.get('unet_depth', 5)
        unet_init_width = mask_config.get('unet_init_width', 32)

    model = _ScoreWidgetsMask(unet_depth=unet_depth, unet_init_width=unet_init_width)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)

    # Handle different checkpoint formats: either a bare state_dict or a
    # training checkpoint wrapping it under 'model'.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            state_dict = checkpoint['model']

    # ScoreWidgetsMask saves as {'mask': {UNet weights}}
    if isinstance(state_dict, dict) and 'mask' in state_dict:
        model.mask.load_state_dict(state_dict['mask'])
    else:
        # Try loading directly (may have 'mask.' prefix from nn.Module default);
        # strict=False tolerates missing/unexpected keys.
        model.load_state_dict(state_dict, strict=False)

    model.eval()
    model.to(device)
    logging.info('MaskService: state_dict loaded: %s (depth=%d, width=%d)',
                 resolved, unet_depth, unet_init_width)
    return model
|
| 81 |
+
|
| 82 |
+
|
| 83 |
class StaffMask:
|
| 84 |
"""Staff mask representation."""
|
| 85 |
|
|
|
|
| 96 |
}
|
| 97 |
|
| 98 |
|
| 99 |
+
class MaskService:
|
| 100 |
+
"""Mask prediction service. Supports TorchScript and state_dict formats."""
|
| 101 |
|
| 102 |
DEFAULT_TRANS = ['Mono', 'HWC2CHW']
|
| 103 |
DEFAULT_SLICING_WIDTH = 512
|
| 104 |
|
| 105 |
def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
|
| 106 |
+
self.device = device
|
| 107 |
+
self.model = _load_mask_model(model_path, device)
|
| 108 |
self.composer = Composer(trans or self.DEFAULT_TRANS)
|
| 109 |
self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
|
| 110 |
|
|
|
|
| 135 |
batch = torch.from_numpy(staves).to(self.device)
|
| 136 |
|
| 137 |
# Inference
|
| 138 |
+
with torch.no_grad():
|
| 139 |
+
output = self.model(batch) # (batch, channel, height, width)
|
| 140 |
|
| 141 |
# Splice output
|
| 142 |
hotmap = splice_output_tensor(output, soft=True) # (channel, height, width)
|