JacobLinCool committed on
Commit
50928b1
·
verified ·
1 Parent(s): f7793f8

Upload model.py

Browse files
Files changed (1) hide show
  1. TaikoChartEstimator/model/model.py +82 -0
TaikoChartEstimator/model/model.py CHANGED
@@ -256,6 +256,88 @@ class TaikoChartEstimator(nn.Module, PyTorchModelHubMixin):
256
  instance_embeddings=instance_embeddings,
257
  )
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  def predict(
260
  self,
261
  instances: torch.Tensor,
 
256
  instance_embeddings=instance_embeddings,
257
  )
258
 
259
def get_instance_scores(
    self,
    instance_embeddings: torch.Tensor,
    difficulty_class_id: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Estimate a difficulty score for each individual instance.

    This acts as a "probe": every instance is treated as if it were a bag
    containing only itself, and the bag-level heads are applied to it.

    Args:
        instance_embeddings: [batch, n_instances, d_model]. May be
            non-contiguous (e.g. the result of a transpose).
        difficulty_class_id: [batch] Optional difficulty class used for
            calibration. When omitted, the class is predicted per instance
            with ``self.difficulty_classifier``.

    Returns:
        raw_scores: [batch, n_instances] Unbounded raw scores
        star_ratings: [batch, n_instances] Calibrated star ratings

        For an aggregator of unknown kind both tensors are all zeros.
    """
    batch_size, n_instances, d_instance = instance_embeddings.shape
    aggregator = self.aggregator

    # ``reshape`` (not ``view``) so non-contiguous embeddings are accepted.
    flat_feat = instance_embeddings.reshape(-1, d_instance)

    # Prefer the explicit class when the model exposes it; otherwise fall
    # back to duck-typing on ``output_proj`` instead of raising
    # AttributeError on the missing class attribute.
    gated_cls = getattr(type(self), "GatedMILAggregator", None)
    if gated_cls is not None:
        is_gated = isinstance(aggregator, gated_cls)
    else:
        is_gated = hasattr(aggregator, "output_proj")

    if hasattr(aggregator, "n_branches"):
        # For a single-instance bag, mean pooling, top-k pooling and every
        # branch pooling all reduce to the instance itself, so the fusion
        # input is the instance tiled once per pooled component.
        # NOTE(review): relies on the aggregator concatenating
        # [mean, topk, branch_1, ..., branch_n] -- confirm against the
        # aggregator implementation.
        n_repeats = 2 + aggregator.n_branches
        fused_input = flat_feat.repeat(1, n_repeats)
        bag_embedding = aggregator.fusion(fused_input)  # [batch * n_inst, out]
    elif is_gated:
        # Gated aggregation of a single instance is assumed to reduce to
        # the output projection of that instance.
        bag_embedding = aggregator.output_proj(flat_feat)
    else:
        # Unknown aggregator: keep the best-effort behaviour of returning
        # zero scores rather than guessing at its pooling semantics.
        return (
            instance_embeddings.new_zeros(batch_size, n_instances),
            instance_embeddings.new_zeros(batch_size, n_instances),
        )

    # Raw score head: [batch * n_inst, 1] -> [batch, n_instances]
    raw_score = self.raw_score_head(bag_embedding).reshape(
        batch_size, n_instances
    )

    if difficulty_class_id is None:
        # No difficulty provided: predict it from each single instance.
        logits = self.difficulty_classifier(bag_embedding)
        diff_ids = logits.argmax(dim=-1)  # [batch * n_inst]
    else:
        # Expand the per-bag class to per-instance, batch-major, matching
        # the flattening order of ``raw_score``.
        diff_ids = (
            difficulty_class_id.reshape(-1)
            .repeat_interleave(n_instances)
            .to(raw_score.device)
        )

    stars = self.calibrator(raw_score.reshape(-1), diff_ids)  # [batch * n_inst]
    stars = stars.reshape(batch_size, n_instances)

    return raw_score, stars
340
+
341
  def predict(
342
  self,
343
  instances: torch.Tensor,