BEST-RQ-2 / audio-embeddings /src /utils /lr_schedulers.py
ltuncay's picture
Submission to the Interspeech 2026 Audio Encoder Capability Challenge
eca55dc verified
import math
class LinearWarmupCosineDecay:
    """Step-indexed learning-rate multiplier: linear warmup, then cosine decay.

    Calling an instance with a step number returns a multiplicative factor
    relative to the base learning rate:

    * for ``step < warmup_steps`` the factor rises linearly as
      ``step / warmup_steps`` (0.0 at step 0);
    * from ``warmup_steps`` on, it follows a half cosine from 1.0 down to
      ``final_lr_ratio``, reached at ``total_steps`` and held afterwards.

    NOTE(review): intended as the ``lr_lambda`` of a ``LambdaLR``-style
    scheduler (the original comments refer to "lambda") — confirm with caller.
    The three constructor arguments are stored under their own names, so the
    instance ``__dict__`` (which such schedulers may serialize) is unchanged.

    Args:
        warmup_steps: Length of the linear warmup phase, in steps.
        total_steps: Step at which the decay reaches ``final_lr_ratio``
            (when ``total_steps > warmup_steps``; otherwise the floor is
            reached one step after warmup ends).
        final_lr_ratio: Factor the schedule settles at after ``total_steps``.
    """

    def __init__(
        self,
        warmup_steps: int,
        total_steps: int,
        final_lr_ratio: float,
    ):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.final_lr_ratio = final_lr_ratio

    def __call__(self, current_step: int) -> float:
        """Return the learning-rate factor to use at ``current_step``."""
        in_warmup = current_step < self.warmup_steps
        if in_warmup:
            return self._warmup_factor(current_step)
        return self._decay_factor(current_step)

    def _warmup_factor(self, step: int) -> float:
        """Linear ramp ``step / warmup_steps`` used before warmup ends."""
        # max(1, ...) keeps the division defined when warmup_steps is 0.
        ramp_len = max(1, self.warmup_steps)
        return float(step) / float(ramp_len)

    def _decay_factor(self, step: int) -> float:
        """Cosine interpolation from 1.0 down to ``final_lr_ratio``."""
        # Length of the decay phase; at least 1 so a schedule whose
        # total_steps does not exceed warmup_steps still avoids dividing by 0.
        decay_len = max(1, self.total_steps - self.warmup_steps)
        elapsed = step - self.warmup_steps
        # Fraction of the decay phase completed, clipped to [0, 1] so steps
        # past total_steps stay pinned at the floor.
        frac = min(1.0, max(0.0, float(elapsed) / float(decay_len)))
        # Half-cosine weight: 1.0 at frac == 0, 0.0 at frac == 1.
        half_cosine = 0.5 * (1.0 + math.cos(math.pi * frac))
        floor = self.final_lr_ratio
        return floor + (1.0 - floor) * half_cosine