k-l-lambda committed on
Commit
1958836
·
1 Parent(s): 6a8cad3

update: export from starry-refactor 2026-02-20 15:25

Browse files
backend/python-services/predictors/unet.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UNet model implementation.
3
+ Matches the architecture from deep-starry/starry/unet/ for loading .chkpt checkpoints.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) applied twice in sequence."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # A falsy mid_channels (None/0) defaults to the output width.
        mid = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, mid, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Attribute name `double_conv` is kept for state_dict compatibility.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv => BN => ReLU stages to ``x``."""
        return self.double_conv(x)
29
+
30
+
31
class Down(nn.Module):
    """Downscaling stage: 2x2 max-pool followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Attribute name `maxpool_conv` is kept for state_dict compatibility.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Halve the spatial resolution of ``x``, then convolve."""
        return self.maxpool_conv(x)
43
+
44
+
45
class Up(nn.Module):
    """Upscaling stage: upsample, pad to the skip connection, then DoubleConv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if not bilinear:
            # Learned upsampling halves the channel count itself.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Fixed bilinear upsampling; the conv's mid width is halved instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, concat, convolve."""
        upsampled = self.up(x1)
        # Pad so the upsampled map matches the skip connection's height/width;
        # odd differences put the extra pixel on the right/bottom.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        padding = [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2]
        upsampled = F.pad(upsampled, padding)
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
65
+
66
+
67
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute name `conv` is kept for state_dict compatibility.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to the output channel count, pixel-wise."""
        return self.conv(x)
74
+
75
+
76
class UNet(nn.Module):
    """UNet encoder/decoder; architecture matches deep-starry checkpoints.

    Attribute names (`inc`, `outc`, `downs`, `ups`) are part of the saved
    state_dict layout and must not be renamed.
    """

    def __init__(self, n_channels, n_classes, classify_out=True, bilinear=True, depth=4, init_width=64):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.classify_out = classify_out
        self.depth = depth
        # With bilinear upsampling the bottleneck/decoder widths are halved.
        halve = 2 if bilinear else 1

        self.inc = DoubleConv(n_channels, init_width)
        self.outc = OutConv(init_width, n_classes)

        down_stages = []
        for level in range(depth):
            cin = init_width << level  # init_width * 2**level
            cout = cin * 2
            if level == depth - 1:
                # Deepest stage feeds the first Up; shrink for bilinear mode.
                cout //= halve
            down_stages.append(Down(cin, cout))

        up_stages = []
        for level in range(depth):
            cin = init_width << (depth - level)
            cout = cin // 2
            if level < depth - 1:
                # All but the last Up feed another Up; shrink for bilinear mode.
                cout //= halve
            up_stages.append(Up(cin, cout, bilinear))

        self.downs = nn.ModuleList(down_stages)
        self.ups = nn.ModuleList(up_stages)

    def forward(self, input):
        """Encode, decode with skip connections, optionally classify.

        Returns raw decoder features when ``classify_out`` is False,
        otherwise per-pixel logits from the final 1x1 conv.
        """
        skips = []
        x = self.inc(input)

        # Encoder: remember each pre-pool feature map as a skip connection.
        for down in self.downs:
            skips.append(x)
            x = down(x)

        # Decoder: consume the skip connections in reverse (deepest first).
        for up, skip in zip(self.ups, reversed(skips)):
            x = up(x, skip)

        if not self.classify_out:
            return x

        return self.outc(x)
backend/python-services/services/gauge_service.py CHANGED
@@ -1,13 +1,19 @@
1
  """
2
  Gauge prediction service.
3
  Predicts staff gauge (height and slope) map.
 
4
  """
5
 
 
 
 
6
  import numpy as np
7
  import torch
 
8
  import PIL.Image
9
 
10
- from predictors.torchscript_predictor import TorchScriptPredictor
 
11
  from common.image_utils import (
12
  array_from_image_stream, slice_feature, splice_output_tensor,
13
  gauge_to_rgb, encode_image_base64, encode_image_bytes,
@@ -16,6 +22,80 @@ from common.image_utils import (
16
  from common.transform import Composer
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class StaffGauge:
20
  """Staff gauge representation."""
21
 
@@ -32,14 +112,15 @@ class StaffGauge:
32
  }
33
 
34
 
35
- class GaugeService(TorchScriptPredictor):
36
- """Gauge prediction service using TorchScript model."""
37
 
38
  DEFAULT_TRANS = ['Mono', 'HWC2CHW']
39
  DEFAULT_SLICING_WIDTH = 512
40
 
41
  def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
42
- super().__init__(model_path, device)
 
43
  self.composer = Composer(trans or self.DEFAULT_TRANS)
44
  self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
45
 
@@ -70,7 +151,8 @@ class GaugeService(TorchScriptPredictor):
70
  batch = torch.from_numpy(staves).to(self.device)
71
 
72
  # Inference
73
- output = self.run_inference(batch) # (batch, channel, height, width)
 
74
 
75
  # Splice output
76
  hotmap = splice_output_tensor(output, soft=True) # (channel, height, width)
 
1
  """
2
  Gauge prediction service.
3
  Predicts staff gauge (height and slope) map.
4
+ Supports both TorchScript (.pt) and state_dict (.chkpt) model formats.
5
  """
6
 
7
+ import os
8
+ import logging
9
+ from collections import OrderedDict
10
  import numpy as np
11
  import torch
12
+ import yaml
13
  import PIL.Image
14
 
15
+ from predictors.torchscript_predictor import resolve_model_path
16
+ from predictors.unet import UNet
17
  from common.image_utils import (
18
  array_from_image_stream, slice_feature, splice_output_tensor,
19
  gauge_to_rgb, encode_image_base64, encode_image_bytes,
 
22
  from common.transform import Composer
23
 
24
 
25
class _ScoreRegression(torch.nn.Module):
    """ScoreRegression architecture, rebuilt for loading .chkpt checkpoints."""

    def __init__(self, in_channels=1, out_channels=2, unet_depth=6, unet_init_width=32):
        super().__init__()
        # Attribute must be named `backbone` to match the checkpoint's keys.
        self.backbone = UNet(
            in_channels,
            out_channels,
            depth=unet_depth,
            init_width=unet_init_width,
        )

    def forward(self, input):
        """Delegate straight to the UNet backbone."""
        return self.backbone(input)
34
+
35
+
36
def _load_gauge_model(model_path, device):
    """Load the gauge model, handling both TorchScript and state_dict formats.

    Tries ``torch.jit.load`` first; on failure falls back to building a
    ``_ScoreRegression`` (hyper-parameters read from a sibling ``.state.yaml``
    when present) and loading the checkpoint as a state_dict.

    Args:
        model_path: Model path/spec understood by ``resolve_model_path``.
        device: Device the model should be loaded onto.

    Returns:
        An eval-mode ``torch.nn.Module`` placed on ``device``.
    """
    resolved = resolve_model_path(model_path)

    # Try TorchScript first.
    try:
        model = torch.jit.load(resolved, map_location=device)
        model.eval()
        logging.info('GaugeService: TorchScript model loaded: %s', resolved)
        return model
    except Exception as e:
        logging.info('GaugeService: not TorchScript (%s), trying state_dict...', str(e)[:60])

    # Read model config from .state.yaml next to the checkpoint, if present.
    model_dir = os.path.dirname(resolved)
    state_file = os.path.join(model_dir, '.state.yaml')
    unet_depth = 6
    unet_init_width = 32
    out_channels = 2
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            # `or {}` guards against an empty/null YAML document, which
            # safe_load returns as None and would crash the .get chain below.
            state = yaml.safe_load(f) or {}
        # Each level may be explicitly null in the YAML, hence the or-guards.
        model_args = (state.get('model') or {}).get('args') or {}
        backbone = model_args.get('backbone') or {}
        unet_depth = backbone.get('unet_depth', 6)
        unet_init_width = backbone.get('unet_init_width', 32)
        out_channels = model_args.get('out_channels', 2)

    model = _ScoreRegression(out_channels=out_channels, unet_depth=unet_depth, unet_init_width=unet_init_width)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoint files from trusted sources here.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)

    # Handle different checkpoint formats: some save {'model': state_dict}.
    state_dict = checkpoint
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']

    # Strip common prefixes from the training wrapper (ScoreRegressionLoss.deducer.*).
    if isinstance(state_dict, dict):
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            new_key = key
            if new_key.startswith('deducer.'):
                new_key = new_key[len('deducer.'):]
            cleaned[new_key] = value
        # Remove non-model keys (e.g. channel_weights from the Loss wrapper).
        cleaned = OrderedDict((k, v) for k, v in cleaned.items()
                              if k.startswith('backbone.'))
        state_dict = cleaned

    # strict=False: tolerate missing/unexpected keys; coverage is logged below.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)

    # Log how many of the module's keys the checkpoint actually covered.
    model_keys = set(model.state_dict().keys())
    loaded_keys = set(state_dict.keys())
    matched = model_keys & loaded_keys
    logging.info('GaugeService: state_dict loaded: %s (%d/%d keys matched, depth=%d, width=%d)',
                 resolved, len(matched), len(model_keys), unet_depth, unet_init_width)
    return model
97
+
98
+
99
  class StaffGauge:
100
  """Staff gauge representation."""
101
 
 
112
  }
113
 
114
 
115
+ class GaugeService:
116
+ """Gauge prediction service. Supports TorchScript and state_dict formats."""
117
 
118
  DEFAULT_TRANS = ['Mono', 'HWC2CHW']
119
  DEFAULT_SLICING_WIDTH = 512
120
 
121
  def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
122
+ self.device = device
123
+ self.model = _load_gauge_model(model_path, device)
124
  self.composer = Composer(trans or self.DEFAULT_TRANS)
125
  self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
126
 
 
151
  batch = torch.from_numpy(staves).to(self.device)
152
 
153
  # Inference
154
+ with torch.no_grad():
155
+ output = self.model(batch) # (batch, channel, height, width)
156
 
157
  # Splice output
158
  hotmap = splice_output_tensor(output, soft=True) # (channel, height, width)
backend/python-services/services/mask_service.py CHANGED
@@ -1,13 +1,18 @@
1
  """
2
  Mask prediction service.
3
  Generates staff foreground/background mask.
 
4
  """
5
 
 
 
6
  import numpy as np
7
  import torch
 
8
  import PIL.Image
9
 
10
- from predictors.torchscript_predictor import TorchScriptPredictor
 
11
  from common.image_utils import (
12
  array_from_image_stream, slice_feature, splice_output_tensor,
13
  mask_to_alpha, encode_image_base64, encode_image_bytes,
@@ -16,6 +21,65 @@ from common.image_utils import (
16
  from common.transform import Composer
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class StaffMask:
20
  """Staff mask representation."""
21
 
@@ -32,14 +96,15 @@ class StaffMask:
32
  }
33
 
34
 
35
- class MaskService(TorchScriptPredictor):
36
- """Mask prediction service using TorchScript model."""
37
 
38
  DEFAULT_TRANS = ['Mono', 'HWC2CHW']
39
  DEFAULT_SLICING_WIDTH = 512
40
 
41
  def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
42
- super().__init__(model_path, device)
 
43
  self.composer = Composer(trans or self.DEFAULT_TRANS)
44
  self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
45
 
@@ -70,7 +135,8 @@ class MaskService(TorchScriptPredictor):
70
  batch = torch.from_numpy(staves).to(self.device)
71
 
72
  # Inference
73
- output = self.run_inference(batch) # (batch, channel, height, width)
 
74
 
75
  # Splice output
76
  hotmap = splice_output_tensor(output, soft=True) # (channel, height, width)
 
1
  """
2
  Mask prediction service.
3
  Generates staff foreground/background mask.
4
+ Supports both TorchScript (.pt) and state_dict (.chkpt) model formats.
5
  """
6
 
7
+ import os
8
+ import logging
9
  import numpy as np
10
  import torch
11
+ import yaml
12
  import PIL.Image
13
 
14
+ from predictors.torchscript_predictor import resolve_model_path
15
+ from predictors.unet import UNet
16
  from common.image_utils import (
17
  array_from_image_stream, slice_feature, splice_output_tensor,
18
  mask_to_alpha, encode_image_base64, encode_image_bytes,
 
21
  from common.transform import Composer
22
 
23
 
24
class _ScoreWidgetsMask(torch.nn.Module):
    """ScoreWidgetsMask architecture, rebuilt for loading .chkpt checkpoints."""

    def __init__(self, in_channels=1, mask_channels=2, unet_depth=5, unet_init_width=32):
        super().__init__()
        # Attribute must be named `mask` to match the checkpoint's keys.
        self.mask = UNet(
            in_channels,
            mask_channels,
            depth=unet_depth,
            init_width=unet_init_width,
        )

    def forward(self, x):
        """UNet logits squashed to [0, 1] via sigmoid."""
        logits = self.mask(x)
        return torch.sigmoid(logits)
33
+
34
+
35
def _load_mask_model(model_path, device):
    """Load the mask model, handling both TorchScript and state_dict formats.

    Tries ``torch.jit.load`` first; on failure builds a ``_ScoreWidgetsMask``
    (hyper-parameters read from a sibling ``.state.yaml`` when present) and
    loads the checkpoint as a state_dict.

    Args:
        model_path: Model path/spec understood by ``resolve_model_path``.
        device: Device the model should be loaded onto.

    Returns:
        An eval-mode ``torch.nn.Module`` placed on ``device``.
    """
    resolved = resolve_model_path(model_path)

    # Try TorchScript first.
    try:
        model = torch.jit.load(resolved, map_location=device)
        model.eval()
        logging.info('MaskService: TorchScript model loaded: %s', resolved)
        return model
    except Exception as e:
        logging.info('MaskService: not TorchScript (%s), trying state_dict...', str(e)[:60])

    # Read model config from .state.yaml next to the checkpoint, if present.
    model_dir = os.path.dirname(resolved)
    state_file = os.path.join(model_dir, '.state.yaml')
    unet_depth = 5
    unet_init_width = 32
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            # `or {}` guards against an empty/null YAML document, which
            # safe_load returns as None and would crash the .get chain below.
            state = yaml.safe_load(f) or {}
        # Each level may be explicitly null in the YAML, hence the or-guards.
        mask_config = ((state.get('model') or {}).get('args') or {}).get('mask') or {}
        unet_depth = mask_config.get('unet_depth', 5)
        unet_init_width = mask_config.get('unet_init_width', 32)

    model = _ScoreWidgetsMask(unet_depth=unet_depth, unet_init_width=unet_init_width)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoint files from trusted sources here.
    checkpoint = torch.load(resolved, map_location=device, weights_only=False)

    # Handle different checkpoint formats: some save {'model': state_dict}.
    state_dict = checkpoint
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']

    # ScoreWidgetsMask saves as {'mask': {UNet weights}}
    if isinstance(state_dict, dict) and 'mask' in state_dict:
        model.mask.load_state_dict(state_dict['mask'])
    else:
        # Try loading directly (keys may carry the 'mask.' prefix already);
        # strict=False tolerates missing/unexpected keys.
        model.load_state_dict(state_dict, strict=False)

    model.eval()
    model.to(device)
    logging.info('MaskService: state_dict loaded: %s (depth=%d, width=%d)',
                 resolved, unet_depth, unet_init_width)
    return model
81
+
82
+
83
  class StaffMask:
84
  """Staff mask representation."""
85
 
 
96
  }
97
 
98
 
99
+ class MaskService:
100
+ """Mask prediction service. Supports TorchScript and state_dict formats."""
101
 
102
  DEFAULT_TRANS = ['Mono', 'HWC2CHW']
103
  DEFAULT_SLICING_WIDTH = 512
104
 
105
  def __init__(self, model_path, device='cuda', trans=None, slicing_width=None):
106
+ self.device = device
107
+ self.model = _load_mask_model(model_path, device)
108
  self.composer = Composer(trans or self.DEFAULT_TRANS)
109
  self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
110
 
 
135
  batch = torch.from_numpy(staves).to(self.device)
136
 
137
  # Inference
138
+ with torch.no_grad():
139
+ output = self.model(batch) # (batch, channel, height, width)
140
 
141
  # Splice output
142
  hotmap = splice_output_tensor(output, soft=True) # (channel, height, width)