#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
HF-driven inference for MAGEL with segment-level autoregressive generation.

Uses from HF sample:
- text instruction/template tokens (token_ids scaffold)
- control tokens: chord_ids/structure_ids

Does NOT use:
- ground-truth audio token values as input (audio codebook positions are masked)
"""

import argparse
import contextlib
import importlib
import json
import os
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Optional

import numpy as np
import torch

from runtime_utils import (
    load_magel_checkpoint,
    load_music_dataset,
    maybe_compile_model,
    maybe_mark_compile_step_begin,
    resolve_device,
    seed_everything,
)
from vocab import (
    CHORD_BOS_ID,
    CHORD_EOS_ID,
    STRUCTURE_EOS_ID,
    chord_id_to_label,
    structure_id_to_label,
)
from modelling_qwen3 import MAGEL

# The MuCodec checkout is expected to live next to this script.
REPO_ROOT = Path(__file__).resolve().parent
MUCODEC_ROOT = REPO_ROOT / "MuCodec"


@dataclass
class TokenLayout:
    """Id layout of the joint text/audio vocabulary.

    Ids ``[0, num_text_token)`` are text tokens, ids
    ``[audio_start, audio_end)`` are audio codebook entries, and the three
    ids that follow are the MASK / BOS / EOS audio specials, in that order.
    """

    num_text_token: int
    num_audio_codebook: int = 16384

    @property
    def audio_start(self) -> int:
        """First audio codebook id (inclusive)."""
        return self.num_text_token

    @property
    def audio_end(self) -> int:
        """One past the last audio codebook id (exclusive bound)."""
        return self.num_text_token + self.num_audio_codebook

    @property
    def mask_audio(self) -> int:
        """Id used to mask out audio codebook positions."""
        return self.audio_end

    @property
    def bos_audio(self) -> int:
        """Id marking the beginning of an audio segment."""
        return self.audio_end + 1

    @property
    def eos_audio(self) -> int:
        """Id marking the end of an audio segment."""
        return self.audio_end + 2


@dataclass
class SegmentSpan:
    """One BOS..EOS audio segment located inside a token template."""

    seg_idx: int  # running index of the segment within the sample
    bos_pos: int  # sequence position of the segment's BOS token
    eos_pos: int  # sequence position of the matching EOS token
    audio_positions: list[int]  # audio codebook positions strictly between them


@dataclass
class HFTemplateSample:
    """A dataset sample prepared as a generation scaffold."""

    song_id: str
    num_text_token: int
    template_ids: torch.Tensor  # [T], original token_ids
    input_ids: torch.Tensor  # [T], audio codebook replaced with MASK_AUDIO
    chord_ids: torch.Tensor  # [T]
    structure_ids: torch.Tensor  # [T]
    condition_mask: torch.Tensor  # [T]
    is_audio_codebook: torch.Tensor  # [T]
    is_eos: torch.Tensor  # [T]
    segments: list[SegmentSpan]
    raw_item: dict[str, Any]

    @property
    def seq_len(self) -> int:
        """Number of tokens in the scaffold."""
        return int(self.input_ids.numel())


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Returns:
        The populated namespace. ``json_output_dir`` and ``wav_output_dir``
        are empty strings unless given explicitly, in which case the output
        writer falls back to ``<output_dir>/json`` and ``<output_dir>/wav``.
    """
    parser = argparse.ArgumentParser(
        description="Segment-wise AR generation from HF controls/scaffold."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="./output_qwen3_0p6b_train/final",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument("--split", type=str, default="validation")
    parser.add_argument("--sample_idx", type=int, default=0)
    parser.add_argument(
        "--tokenizer_path", type=str, default="checkpoints/Qwen3-0.6B"
    )
    parser.add_argument(
        "--num_audio_codebook",
        type=int,
        default=None,
        help="Audio codebook size. Defaults to checkpoint metadata when available.",
    )
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.90)
    parser.add_argument("--greedy", action="store_true", default=False)
    parser.add_argument("--max_audio_tokens", type=int, default=0)
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    # NOTE: --use_cache already defaults to True, so passing it is a no-op;
    # caching is only ever switched off through --no_cache.
    parser.add_argument("--use_cache", action="store_true", default=True)
    parser.add_argument("--no_cache", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="reduce-overhead",
        choices=["default", "reduce-overhead", "max-autotune"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )
    parser.add_argument("--output_dir", type=str, default="predictions")
    parser.add_argument("--output_prefix", type=str, default="")
    # Left empty by default so that the writer derives the directory from
    # --output_dir; a hard-coded default would silently ignore --output_dir.
    parser.add_argument(
        "--json_output_dir",
        type=str,
        default="",
        help="Directory for chord/segment json. Default: OUTPUT_DIR/json",
    )
    parser.add_argument(
        "--mucodec_device",
        type=str,
        default="auto",
        help="Device string for MuCodec, for example cuda:0.",
    )
    parser.add_argument(
        "--mucodec_layer_num",
        type=int,
        default=7,
        help="MuCodec layer_num passed to the official decoder.",
    )
    parser.add_argument(
        "--mucodec_duration",
        type=float,
        default=40.96,
        help="Chunk duration argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_guidance_scale",
        type=float,
        default=1.5,
        help="Guidance scale argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_num_steps",
        type=int,
        default=20,
        help="Sampling steps argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_sample_rate",
        type=int,
        default=48000,
        help="Sample rate used when saving decoded wav.",
    )
    parser.add_argument(
        "--wav_output_dir",
        type=str,
        default="",
        help="Directory for decoded wav. Default: OUTPUT_DIR/wav",
    )
    return parser.parse_args()


def resolve_runtime_device_str(device_arg: str) -> str:
    """Turn a device option into a concrete device string.

    Any value other than ``"auto"`` is returned unchanged. ``"auto"`` picks
    ``"cuda:0"``, then ``"mps"``, then ``"cpu"``, in that order of preference.
    """
    if device_arg != "auto":
        return device_arg
    if torch.cuda.is_available():
        return "cuda:0"
    # Some torch builds do not ship the MPS backend module at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


@contextlib.contextmanager
def pushd(path: str):
    """Temporarily change the working directory to ``path``.

    The previous working directory is restored on exit, even if the body
    raises.
    """
    prev = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(prev)


def ensure_sys_path(path: str) -> None:
    """Put ``path`` at the front of ``sys.path`` unless it is empty or present."""
    if path and path not in sys.path:
        sys.path.insert(0, path)


def get_mucodec_root() -> str:
    """Return the MuCodec directory as a string after validating it.

    Raises:
        FileNotFoundError: If the directory or its ``generate.py`` is missing.
    """
    if not MUCODEC_ROOT.is_dir():
        raise FileNotFoundError(f"MuCodec directory not found: {MUCODEC_ROOT}")
    if not (MUCODEC_ROOT / "generate.py").is_file():
        raise FileNotFoundError(
            f"MuCodec entrypoint not found: {MUCODEC_ROOT / 'generate.py'}"
        )
    return str(MUCODEC_ROOT)


def import_mucodec_class() -> tuple[Any, str]:
    """Import the ``MuCodec`` class from the MuCodec checkout.

    Returns:
        A ``(MuCodec class, repo path)`` pair.

    Raises:
        FileNotFoundError: If the MuCodec checkout is missing or incomplete.
        ImportError: If ``generate.py`` cannot be imported or does not define
            ``MuCodec``; the underlying exception is chained as the cause.
    """
    repo_path = get_mucodec_root()
    ensure_sys_path(repo_path)
    try:
        module = importlib.import_module("generate")
        return getattr(module, "MuCodec"), repo_path
    except Exception as exc:  # pragma: no cover - env dependent
        # Chain the cause so the original traceback stays visible.
        raise ImportError(
            f"Could not import MuCodec from {repo_path}/generate.py: {exc}"
        ) from exc
def build_mucodec_decoder(args: argparse.Namespace) -> Any:
    """Instantiate the MuCodec decoder from the local MuCodec checkout.

    Args:
        args: Parsed options; ``mucodec_device`` and ``mucodec_layer_num``
            are read.

    Returns:
        The decoder object, tagged with a ``_magel_mucodec_repo`` attribute
        holding the repository path it was built from.

    Raises:
        FileNotFoundError: If the checkpoint or a required auxiliary file is
            missing.
        ImportError: If the MuCodec class cannot be imported.
    """
    MuCodec, resolved_repo = import_mucodec_class()

    ckpt_path = os.path.join(resolved_repo, "ckpt", "mucodec.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"MuCodec checkpoint not found: {ckpt_path}")

    required_local_files = [
        os.path.join(resolved_repo, "tools", "audioldm_48k.pth"),
        os.path.join(resolved_repo, "muq_dev", "muq.pt"),
    ]
    for path in required_local_files:
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"Required MuCodec dependency not found for current folder structure: {path}"
            )

    mucodec_device = resolve_runtime_device_str(args.mucodec_device)
    if resolved_repo:
        print(f"[INFO] resolved MuCodec repo: {resolved_repo}")
    print(f"[INFO] loading MuCodec from {ckpt_path} on {mucodec_device}")

    # Construct inside the repo directory; presumably MuCodec resolves some
    # of its files relative to the working directory — TODO confirm.
    with pushd(resolved_repo):
        decoder = MuCodec(
            model_path=ckpt_path,
            layer_num=int(args.mucodec_layer_num),
            load_main_model=True,
            device=mucodec_device,
        )
    setattr(decoder, "_magel_mucodec_repo", resolved_repo)
    return decoder


def decode_mucodec_codes(
    mucodec_decoder: Any,
    shifted_codes: np.ndarray,
    args: argparse.Namespace,
) -> torch.Tensor:
    """Decode a 1D stream of MuCodec codes into a waveform tensor.

    Args:
        mucodec_decoder: Decoder exposing ``code2sound``.
        shifted_codes: 1D array of codebook indices (audio ids with the
            audio offset already subtracted).
        args: Parsed options; ``mucodec_duration``,
            ``mucodec_guidance_scale`` and ``mucodec_num_steps`` are read.

    Returns:
        A float32 CPU tensor with at least two dimensions.

    Raises:
        ValueError: If ``shifted_codes`` is not one-dimensional.
    """
    if shifted_codes.ndim != 1:
        raise ValueError(
            f"Expected 1D MuCodec token stream, got shape {shifted_codes.shape}"
        )

    codes = torch.from_numpy(shifted_codes.astype(np.int64, copy=False))
    # Add two leading singleton axes: [N] -> [1, 1, N].
    codes = codes.unsqueeze(0).unsqueeze(0)

    repo_path = getattr(mucodec_decoder, "_magel_mucodec_repo", "")
    decode_ctx = pushd(repo_path) if repo_path else contextlib.nullcontext()
    with decode_ctx:
        wave = mucodec_decoder.code2sound(
            codes,
            prompt=None,
            duration=float(args.mucodec_duration),
            guidance_scale=float(args.mucodec_guidance_scale),
            num_steps=int(args.mucodec_num_steps),
            disable_progress=True,
        )

    if not torch.is_tensor(wave):
        wave = torch.as_tensor(wave)
    if wave.ndim == 1:
        wave = wave.unsqueeze(0)
    return wave.detach().cpu().to(torch.float32)


def build_segment_spans(
    template_ids: torch.Tensor,
    is_audio_codebook: torch.Tensor,
    layout: TokenLayout,
) -> list[SegmentSpan]:
    """Pair each audio BOS with the next unused EOS and collect audio slots.

    Args:
        template_ids: 1D token ids of the original template.
        is_audio_codebook: Boolean mask over ``template_ids`` marking audio
            codebook positions.
        layout: Vocabulary layout providing the BOS/EOS ids.

    Returns:
        Spans in sequence order; empty if the template lacks BOS or EOS.
        Pairing stops at the first BOS that has no later EOS left.
    """
    bos_positions = torch.where(template_ids.eq(layout.bos_audio))[0].tolist()
    eos_positions = torch.where(template_ids.eq(layout.eos_audio))[0].tolist()
    if not bos_positions or not eos_positions:
        return []

    # Position index is loop-invariant, so build it once.
    idx = torch.arange(template_ids.numel(), device=template_ids.device)

    spans: list[SegmentSpan] = []
    eos_ptr = 0
    for b in bos_positions:
        # Skip EOS markers located at or before this BOS.
        while eos_ptr < len(eos_positions) and eos_positions[eos_ptr] <= b:
            eos_ptr += 1
        if eos_ptr >= len(eos_positions):
            break
        e = eos_positions[eos_ptr]
        eos_ptr += 1  # each EOS closes at most one segment

        mask = is_audio_codebook & (idx > b) & (idx < e)
        audio_positions = torch.where(mask)[0].tolist()
        spans.append(
            SegmentSpan(
                seg_idx=len(spans),
                bos_pos=int(b),
                eos_pos=int(e),
                audio_positions=[int(p) for p in audio_positions],
            )
        )
    return spans


def load_hf_template_sample(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Load the dataset and build the generation scaffold for one sample."""
    music_ds = load_music_dataset(
        dataset_path=dataset_path,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_codebook,
        use_fast=True,
    )
    return load_hf_template_sample_from_music_dataset(
        music_ds=music_ds,
        sample_idx=sample_idx,
        num_audio_codebook=num_audio_codebook,
    )


def load_hf_template_sample_from_music_dataset(
    music_ds,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Build a generation scaffold from an already loaded music dataset.

    Ground-truth audio codebook ids are replaced with the MASK id in
    ``input_ids`` so they are never fed to the model; the untouched ids are
    kept in ``template_ids``.

    Raises:
        ValueError: If a control tensor's length differs from the token
            sequence length.
    """
    layout = TokenLayout(
        num_text_token=music_ds.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )
    # NOTE(review): reads the dataset's private ``_data`` for the raw record.
    raw_item = music_ds._data[sample_idx]
    row = music_ds[sample_idx]

    template_ids = row["token_ids"].to(torch.long)
    chord_ids = row["chord_ids"].to(torch.long)
    structure_ids = row["structure_ids"].to(torch.long)
    condition_mask = row["condition_mask"].to(torch.bool)

    seq_len = int(template_ids.numel())
    for name, t in [
        ("chord_ids", chord_ids),
        ("structure_ids", structure_ids),
        ("condition_mask", condition_mask),
    ]:
        if int(t.numel()) != seq_len:
            raise ValueError(f"{name} length mismatch: {int(t.numel())} != {seq_len}")

    is_audio_codebook = (template_ids >= layout.audio_start) & (
        template_ids < layout.audio_end
    )
    is_eos = template_ids.eq(layout.eos_audio)

    # Remove GT audio token values from input scaffold.
    input_ids = template_ids.clone()
    input_ids[is_audio_codebook] = layout.mask_audio

    spans = build_segment_spans(template_ids, is_audio_codebook, layout)
    return HFTemplateSample(
        song_id=str(raw_item.get("song_id", f"sample_{sample_idx}")),
        num_text_token=music_ds.num_text_token,
        template_ids=template_ids,
        input_ids=input_ids,
        chord_ids=chord_ids,
        structure_ids=structure_ids,
        condition_mask=condition_mask,
        is_audio_codebook=is_audio_codebook,
        is_eos=is_eos,
        segments=spans,
        raw_item=raw_item,
    )


def apply_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Mask logits outside the top-k set and outside the top-p nucleus.

    Args:
        logits: Scores whose last dimension is the vocabulary.
        top_k: Keep only the ``top_k`` highest logits; ``None`` or ``<= 0``
            disables the filter.
        top_p: Keep the smallest set of highest-probability tokens whose
            cumulative probability reaches ``top_p``; values outside
            ``(0, 1)`` (or ``None``) disable the filter.

    Returns:
        A tensor shaped like ``logits`` with rejected entries set to ``-inf``.
    """
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.shape[-1])
        values, _ = torch.topk(logits, k, dim=-1)
        kth = values[..., -1:]  # smallest kept logit, broadcastable
        logits = logits.masked_fill(logits < kth, float("-inf"))

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)

        # A token is dropped only if the tokens ranked before it already
        # cover top_p. Shifting the mask by one keeps the token that crosses
        # the threshold, so the kept mass is never below top_p, and the
        # best token is always kept.
        reached = cum_probs >= top_p
        remove_mask = torch.zeros_like(reached)
        remove_mask[..., 1:] = reached[..., :-1]

        sorted_logits = sorted_logits.masked_fill(remove_mask, float("-inf"))
        filtered = torch.full_like(logits, float("-inf"))
        filtered.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)
        logits = filtered
    return logits


def sample_from_logits(
    logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Pick one index from a single row of logits.

    Uses argmax when ``greedy`` is set or ``temperature <= 0``; otherwise
    applies temperature, top-k/top-p filtering and multinomial sampling.

    Raises:
        RuntimeError: If filtering leaves no finite logit.
    """
    if greedy or temperature <= 0:
        return int(torch.argmax(logits, dim=-1).item())

    logits = logits / max(temperature, 1e-6)
    logits = apply_top_k_top_p(logits, top_k=top_k, top_p=top_p)
    if not torch.isfinite(logits).any():
        raise RuntimeError("All logits are -inf after filtering.")
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


def sample_audio_token_from_logits(
    logits: torch.Tensor,
    layout: TokenLayout,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Sample a token id restricted to the audio codebook range.

    Only the ``[audio_start, audio_end)`` slice of the vocabulary is
    considered; the returned id is in full-vocabulary coordinates.
    """
    audio_logits = logits[:, layout.audio_start : layout.audio_end]
    sampled_audio_idx = sample_from_logits(
        audio_logits,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        greedy=greedy,
    )
    return int(layout.audio_start + sampled_audio_idx)


def chord_id_to_type(chord_id: int) -> str:
    """Map a chord id to its label, or ``unknown_<id>`` if unrecognised.

    A decoded ``"N"`` is accepted only for id 1 and the chord BOS/EOS ids;
    for any other id it is treated as an unknown chord.
    """
    decoded = chord_id_to_label(chord_id)
    if decoded != "N" or chord_id in {1, CHORD_BOS_ID, CHORD_EOS_ID}:
        return decoded
    return f"unknown_{chord_id}"


def segment_id_to_type(segment_id: int) -> str:
    """Map a structure id to its label, or ``unknown_<id>`` if out of range."""
    # Range-check first so out-of-range ids never reach the label lookup.
    if not 0 <= segment_id <= STRUCTURE_EOS_ID:
        return f"unknown_{segment_id}"
    return structure_id_to_label(segment_id)


def to_intervals(type_ids: list[int], fps: int, mapper) -> list[dict[str, Any]]:
    """Run-length encode per-frame ids into timed intervals.

    Args:
        type_ids: One id per frame.
        fps: Frames per second used to convert frame indices to seconds.
        mapper: Callable turning an id into the interval's ``type`` value.

    Returns:
        Dicts with ``start``/``end`` (seconds, rounded to 6 decimals) and
        ``type``, one per run of equal consecutive ids.
    """
    if not type_ids:
        return []

    out: list[dict[str, Any]] = []
    start = 0
    cur = type_ids[0]
    # Iterate one past the end so the final run is flushed as well.
    for i in range(1, len(type_ids) + 1):
        if i == len(type_ids) or type_ids[i] != cur:
            out.append(
                {
                    "start": round(start / float(fps), 6),
                    "end": round(i / float(fps), 6),
                    "type": mapper(int(cur)),
                }
            )
            if i < len(type_ids):
                start = i
                cur = type_ids[i]
    return out


def merge_same_type_with_small_gap(
    intervals: list[dict[str, Any]], fps: int, max_gap_frames: int = 1
) -> list[dict[str, Any]]:
    """Merge consecutive intervals of equal type separated by a small gap.

    Two neighbours are merged when they share ``type`` and the gap between
    them is at most ``max_gap_frames`` frames. The input dicts are copied,
    not mutated.
    """
    if not intervals:
        return []

    max_gap_s = float(max_gap_frames) / float(fps)
    merged = [dict(intervals[0])]
    for cur in intervals[1:]:
        prev = merged[-1]
        gap_s = float(cur["start"]) - float(prev["end"])
        # Small epsilon absorbs rounding of the interval boundaries.
        if prev.get("type") == cur.get("type") and gap_s <= (max_gap_s + 1e-9):
            prev["end"] = cur["end"]
        else:
            merged.append(dict(cur))
    return merged


@torch.inference_mode()
def generate_segmentwise(
    model: MAGEL,
    sample: HFTemplateSample,
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> tuple[torch.Tensor, int, list[int], list[int]]:
    """Fill the audio slots of one scaffold autoregressively.

    The prefix before the first audio/EOS slot is fed once; afterwards each
    position is either sampled (audio slot), forced to EOS (EOS slot) or
    copied from the scaffold (anything else).

    Args:
        model: MAGEL model exposing ``condition_encoder`` and a forward pass
            that accepts ``cond_precomputed``.
        sample: Scaffold to complete.
        layout: Vocabulary layout.
        device: Device to run on.
        use_cache: Decode incrementally with a KV cache instead of
            re-running the full prefix at every step.
        temperature, top_k, top_p, greedy: Sampling controls.
        max_audio_tokens: Stop once this many audio tokens were sampled;
            ``0`` or less means no limit.

    Returns:
        ``(generated_ids on CPU, number of sampled audio tokens,
        chord id per sampled token, structure id per sampled token)``.
    """
    import time

    seq_template = sample.input_ids.to(device)
    chord_template = sample.chord_ids.to(device)
    structure_template = sample.structure_ids.to(device)
    condition_mask_template = sample.condition_mask.to(device)
    is_audio_code = sample.is_audio_codebook.to(device)
    is_eos = sample.is_eos.to(device)

    slot_positions = torch.where(is_audio_code | is_eos)[0]
    if slot_positions.numel() == 0:
        # No generation slot: return scaffold as-is.
        return seq_template.detach().cpu(), 0, [], []

    start_pos = int(slot_positions[0].item())
    if sample.segments:
        end_pos = int(sample.segments[-1].eos_pos)
    else:
        end_pos = int(slot_positions[-1].item())

    sampled_chord_ids: list[int] = []
    sampled_segment_ids: list[int] = []
    generated_ids = seq_template.clone()
    sampled_count = 0
    past_key_values: Optional[tuple] = None

    # Precompute full-sequence condition once so cached decoding keeps
    # the same global condition-encoder context as training.
    cond_template: torch.Tensor = model.condition_encoder(
        chord_template.unsqueeze(0),
        structure_template.unsqueeze(0),
    )

    # Prefill with fixed prefix.
    full_attention_mask = torch.ones(
        (1, sample.seq_len), dtype=torch.long, device=device
    )
    prefix_ids = generated_ids[:start_pos].unsqueeze(0)
    prefix_attn = full_attention_mask[:, :start_pos]
    model_kwargs = dict(
        input_ids=prefix_ids,
        attention_mask=prefix_attn,
        condition_mask=condition_mask_template[:start_pos].unsqueeze(0),
        cond_precomputed=cond_template[:, :start_pos, :],
        use_cache=use_cache,
    )
    maybe_mark_compile_step_begin(model)
    prefill_t0 = time.perf_counter()
    out = model(**model_kwargs)
    prefill_time_s = time.perf_counter() - prefill_t0
    logits_next = out.logits[:, -1, :]
    if use_cache:
        past_key_values = out.past_key_values

    step_ids = torch.empty((1, 1), dtype=torch.long, device=device)
    decode_time_s = 0.0
    for i in range(start_pos, end_pos + 1):
        # ``logits_next`` holds the prediction for position ``i``.
        if bool(is_audio_code[i].item()):
            if max_audio_tokens > 0 and sampled_count >= max_audio_tokens:
                break
            next_id = sample_audio_token_from_logits(
                logits_next,
                layout=layout,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
            )
            sampled_count += 1
            # Controls are input-aligned to the token sequence.
            cond_pos = i
            sampled_chord_ids.append(int(chord_template[cond_pos].item()))
            sampled_segment_ids.append(int(structure_template[cond_pos].item()))
        elif bool(is_eos[i].item()):
            next_id = layout.eos_audio
        else:
            next_id = int(seq_template[i].item())

        generated_ids[i] = int(next_id)
        if i >= end_pos:
            break  # nothing left to predict after the final position

        if use_cache:
            step_ids[0, 0] = int(next_id)
            # The cache holds positions [0, i) and this step adds position
            # i, so the mask spans i + 1 keys (same convention as the
            # no-cache branch below).
            step_attn = full_attention_mask[:, : i + 1]
            model_kwargs = dict(
                input_ids=step_ids,
                attention_mask=step_attn,
                condition_mask=condition_mask_template[i : i + 1].unsqueeze(0),
                cond_precomputed=cond_template[:, i : i + 1, :],
                past_key_values=past_key_values,
                use_cache=True,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]
            past_key_values = out.past_key_values
        else:
            cur_len = i + 1
            model_kwargs = dict(
                input_ids=generated_ids[:cur_len].unsqueeze(0),
                attention_mask=full_attention_mask[:, :cur_len],
                condition_mask=condition_mask_template[:cur_len].unsqueeze(0),
                cond_precomputed=cond_template[:, :cur_len, :],
                use_cache=False,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]

    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(sampled_count) / decode_time_s
        if decode_time_s > 0 and sampled_count > 0
        else 0.0
    )
    print(
        "[PROFILE] generation "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={sampled_count} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )
    return (
        generated_ids.detach().cpu(),
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    )


@torch.inference_mode()
def batch_generate_segmentwise(
    model: MAGEL,
    samples: list[HFTemplateSample],
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> list[tuple[torch.Tensor, int, list[int], list[int]]]:
    """Batched counterpart of ``generate_segmentwise``.

    Samples are right-padded to a common length, prefilled together and then
    decoded one step at a time with a shared KV cache. Without caching the
    samples are simply processed one by one.

    Returns:
        One ``(generated_ids, sampled_count, chord ids, structure ids)``
        tuple per input sample, in input order. Samples without any
        generation slot are returned unchanged with a count of zero.
    """
    import time

    if not samples:
        return []
    if not use_cache:
        return [
            generate_segmentwise(
                model=model,
                sample=sample,
                layout=layout,
                device=device,
                use_cache=use_cache,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
                max_audio_tokens=max_audio_tokens,
            )
            for sample in samples
        ]

    batch_size = len(samples)
    seq_lens = [sample.seq_len for sample in samples]
    max_seq_len = max(seq_lens)

    # Zero-initialised, right-padded batch buffers.
    seq_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    generated_ids = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    chord_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    structure_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    condition_mask_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_audio_code_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_eos_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.bool, device=device)

    start_positions: list[int] = []
    end_positions: list[int] = []
    sampled_counts = [0 for _ in samples]
    sampled_chord_ids: list[list[int]] = [[] for _ in samples]
    sampled_segment_ids: list[list[int]] = [[] for _ in samples]
    valid_sample_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

    for row_idx, sample in enumerate(samples):
        seq_templates[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        generated_ids[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        chord_templates[row_idx, : sample.seq_len] = sample.chord_ids.to(device)
        structure_templates[row_idx, : sample.seq_len] = sample.structure_ids.to(device)
        condition_mask_templates[row_idx, : sample.seq_len] = sample.condition_mask.to(device)
        is_audio_code_templates[row_idx, : sample.seq_len] = sample.is_audio_codebook.to(device)
        is_eos_templates[row_idx, : sample.seq_len] = sample.is_eos.to(device)

        slot_positions = torch.where(
            is_audio_code_templates[row_idx, : sample.seq_len]
            | is_eos_templates[row_idx, : sample.seq_len]
        )[0]
        if slot_positions.numel() == 0:
            # Nothing to generate: start > end yields zero decode steps.
            valid_sample_mask[row_idx] = False
            start_positions.append(sample.seq_len)
            end_positions.append(sample.seq_len - 1)
            continue

        start_pos = int(slot_positions[0].item())
        if sample.segments:
            end_pos = int(sample.segments[-1].eos_pos)
        else:
            end_pos = int(slot_positions[-1].item())
        start_positions.append(start_pos)
        end_positions.append(end_pos)

    if not bool(valid_sample_mask.any().item()):
        return [
            (sample.input_ids.detach().cpu(), 0, [], [])
            for sample in samples
        ]

    start_positions_t = torch.tensor(start_positions, dtype=torch.long, device=device)
    end_positions_t = torch.tensor(end_positions, dtype=torch.long, device=device)
    prefix_lens = start_positions_t.clone()
    max_prefix_len = int(prefix_lens.max().item())
    max_decode_steps = int((end_positions_t - start_positions_t + 1).clamp_min(0).max().item())

    cond_template = model.condition_encoder(chord_templates, structure_templates)

    # NOTE(review): rows with a shorter prefix are right-padded inside the
    # prefill and only masked out via the attention mask; presumably the
    # model derives positions from that mask — confirm.
    prefix_attention_mask = (
        torch.arange(max_prefix_len, device=device).unsqueeze(0)
        < prefix_lens.unsqueeze(1)
    ).to(torch.long)

    prefill_t0 = time.perf_counter()
    maybe_mark_compile_step_begin(model)
    out = model(
        input_ids=generated_ids[:, :max_prefix_len],
        attention_mask=prefix_attention_mask,
        condition_mask=condition_mask_templates[:, :max_prefix_len],
        cond_precomputed=cond_template[:, :max_prefix_len, :],
        use_cache=True,
    )
    prefill_time_s = time.perf_counter() - prefill_t0

    # Each row's next-token logits sit at its own last real prefix position.
    gather_idx = (prefix_lens - 1).clamp_min(0)
    batch_indices = torch.arange(batch_size, device=device)
    logits_next = out.logits[batch_indices, gather_idx, :]
    past_key_values = out.past_key_values

    step_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
    decode_valid_mask = torch.zeros(
        (batch_size, max_decode_steps), dtype=torch.bool, device=device
    )
    decode_time_s = 0.0

    for step_idx in range(max_decode_steps):
        cur_positions = start_positions_t + step_idx
        active_mask = valid_sample_mask & cur_positions.le(end_positions_t)
        if not bool(active_mask.any().item()):
            break

        next_ids = torch.zeros(batch_size, dtype=torch.long, device=device)
        for row_idx in range(batch_size):
            if not bool(active_mask[row_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            if bool(is_audio_code_templates[row_idx, cur_pos].item()):
                if max_audio_tokens > 0 and sampled_counts[row_idx] >= max_audio_tokens:
                    # Budget exhausted: retire this row for all later steps.
                    valid_sample_mask[row_idx] = False
                    continue
                next_id = sample_audio_token_from_logits(
                    logits_next[row_idx : row_idx + 1],
                    layout=layout,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    greedy=greedy,
                )
                sampled_counts[row_idx] += 1
                sampled_chord_ids[row_idx].append(
                    int(chord_templates[row_idx, cur_pos].item())
                )
                sampled_segment_ids[row_idx].append(
                    int(structure_templates[row_idx, cur_pos].item())
                )
            elif bool(is_eos_templates[row_idx, cur_pos].item()):
                next_id = layout.eos_audio
            else:
                next_id = int(seq_templates[row_idx, cur_pos].item())

            generated_ids[row_idx, cur_pos] = int(next_id)
            next_ids[row_idx] = int(next_id)
            decode_valid_mask[row_idx, step_idx] = True

        if step_idx >= max_decode_steps - 1:
            break

        step_ids[:, 0] = next_ids
        # Keys = padded prefix + every decode step fed so far; rows that
        # were inactive at a step stay masked out for that step.
        step_attention_mask = torch.cat(
            [
                prefix_attention_mask,
                decode_valid_mask[:, : step_idx + 1].to(torch.long),
            ],
            dim=1,
        )
        step_condition_mask = torch.zeros((batch_size, 1), dtype=torch.bool, device=device)
        step_cond = torch.zeros(
            (batch_size, 1, cond_template.shape[-1]),
            dtype=cond_template.dtype,
            device=device,
        )
        for row_idx in range(batch_size):
            if not bool(decode_valid_mask[row_idx, step_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            step_condition_mask[row_idx, 0] = condition_mask_templates[row_idx, cur_pos]
            step_cond[row_idx, 0, :] = cond_template[row_idx, cur_pos, :]

        step_t0 = time.perf_counter()
        maybe_mark_compile_step_begin(model)
        out = model(
            input_ids=step_ids,
            attention_mask=step_attention_mask,
            condition_mask=step_condition_mask,
            cond_precomputed=step_cond,
            past_key_values=past_key_values,
            use_cache=True,
        )
        decode_time_s += time.perf_counter() - step_t0
        logits_next = out.logits[:, -1, :]
        past_key_values = out.past_key_values

    total_sampled_tokens = sum(sampled_counts)
    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(total_sampled_tokens) / decode_time_s
        if decode_time_s > 0 and total_sampled_tokens > 0
        else 0.0
    )
    print(
        "[PROFILE] batch_generation "
        f"batch_size={batch_size} "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={total_sampled_tokens} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )

    outputs: list[tuple[torch.Tensor, int, list[int], list[int]]] = []
    for row_idx, sample in enumerate(samples):
        if not bool((torch.where(sample.is_audio_codebook | sample.is_eos)[0]).numel()):
            outputs.append((sample.input_ids.detach().cpu(), 0, [], []))
            continue
        outputs.append(
            (
                generated_ids[row_idx, : sample.seq_len].detach().cpu(),
                sampled_counts[row_idx],
                sampled_chord_ids[row_idx],
                sampled_segment_ids[row_idx],
            )
        )
    return outputs


def save_outputs(
    output_dir: str,
    output_prefix: str,
    sample: HFTemplateSample,
    layout: TokenLayout,
    generated_ids: torch.Tensor,
    sampled_chord_ids: list[int],
    sampled_segment_ids: list[int],
    args: argparse.Namespace,
    mucodec_decoder: Any = None,
) -> None:
    """Write the decoded wav and the chord/segment json for one generation.

    Args:
        output_dir: Base output directory (always created).
        output_prefix: File name stem; when empty a stem is built from the
            song id, sample index and a timestamp.
        sample: Scaffold the generation was based on.
        layout: Vocabulary layout used to extract audio codes.
        generated_ids: Generated token ids.
        sampled_chord_ids: Chord id per sampled audio token.
        sampled_segment_ids: Structure id per sampled audio token.
        args: Parsed options; output directories, ``fps``, ``sample_idx``
            and the MuCodec settings are read.
        mucodec_decoder: Decoder used for the wav; when ``None`` the wav is
            skipped with a warning and only the json is written.
    """
    import time

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = output_prefix or f"{sample.song_id}_{args.sample_idx}_{stamp}"

    # Explicit directories win; empty values fall back below output_dir.
    json_dir = args.json_output_dir or os.path.join(output_dir, "json")
    wav_dir = args.wav_output_dir or os.path.join(output_dir, "wav")
    Path(json_dir).mkdir(parents=True, exist_ok=True)
    Path(wav_dir).mkdir(parents=True, exist_ok=True)
    json_path = os.path.join(json_dir, f"{prefix}.chord_segment.json")
    wav_path = os.path.join(wav_dir, f"{prefix}.wav")

    gen_full = generated_ids.cpu().numpy().astype(np.int64)
    gen_audio_raw = gen_full[
        (gen_full >= layout.audio_start) & (gen_full < layout.audio_end)
    ]
    # Convert vocabulary ids to codebook indices.
    gen_audio_shift = gen_audio_raw - layout.audio_start

    save_t0 = time.perf_counter()
    if gen_audio_shift.size == 0:
        print("[WARN] No generated MuCodec tokens; skipping wav decode.")
    elif mucodec_decoder is None:
        # The parameter defaults to None; decoding would otherwise crash.
        print("[WARN] No MuCodec decoder provided; skipping wav decode.")
    else:
        import torchaudio

        wave = decode_mucodec_codes(mucodec_decoder, gen_audio_shift, args)
        torchaudio.save(wav_path, wave, int(args.mucodec_sample_rate))
        print(f"[OK] {wav_path}")

    chord_intervals = to_intervals(
        sampled_chord_ids, fps=int(args.fps), mapper=chord_id_to_type
    )
    segment_intervals = to_intervals(
        sampled_segment_ids, fps=int(args.fps), mapper=segment_id_to_type
    )
    # PAD is used for EOS-related conditioning; drop it in exported json.
    chord_intervals = [x for x in chord_intervals if x.get("type") != "pad"]
    segment_intervals = [x for x in segment_intervals if x.get("type") != "pad"]
    chord_intervals = merge_same_type_with_small_gap(
        chord_intervals, fps=int(args.fps), max_gap_frames=1
    )
    segment_intervals = merge_same_type_with_small_gap(
        segment_intervals, fps=int(args.fps), max_gap_frames=1
    )

    chord_segment = {
        "song_id": sample.song_id,
        "sample_idx": int(args.sample_idx),
        "fps": int(args.fps),
        "generated_audio_count": int(gen_audio_raw.shape[0]),
        "chord": chord_intervals,
        "segment": segment_intervals,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(chord_segment, f, ensure_ascii=False, indent=2)
    print(f"[OK] {json_path}")

    save_time_s = time.perf_counter() - save_t0
    print(
        "[PROFILE] save "
        f"save_s={save_time_s:.3f} "
        f"generated_audio_count={int(gen_audio_raw.shape[0])}"
    )


def main() -> None:
    """Script entry point: load model and sample, generate, save outputs."""
    import time

    args = parse_args()
    seed_everything(args.seed)

    use_cache = args.use_cache and not args.no_cache
    device = resolve_device(args.device)
    dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.dtype]
    if device.type == "cpu" and dtype != torch.float32:
        print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
        dtype = torch.float32

    print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}")
    print(f"[INFO] loading model from {args.model_path}")
    model = load_magel_checkpoint(
        checkpoint_path=args.model_path,
        device=device,
        dtype=dtype,
        attn_implementation=args.attn_implementation,
    )
    model = maybe_compile_model(
        model,
        enabled=bool(args.compile),
        mode=str(args.compile_mode),
    )

    # The command line wins; otherwise use the checkpoint's recorded size.
    num_audio_codebook = (
        int(args.num_audio_codebook)
        if args.num_audio_codebook is not None
        else int(getattr(model.config, "magel_num_audio_token", 16384))
    )
    print(f"[INFO] num_audio_codebook={num_audio_codebook}")

    print(f"[INFO] loading HF sample idx={args.sample_idx} from {args.dataset_path}")
    sample = load_hf_template_sample(
        dataset_path=args.dataset_path,
        split=args.split,
        tokenizer_path=args.tokenizer_path,
        sample_idx=args.sample_idx,
        num_audio_codebook=num_audio_codebook,
    )
    layout = TokenLayout(
        num_text_token=sample.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )
    print(
        f"[INFO] song_id={sample.song_id}, seq_len={sample.seq_len}, segments={len(sample.segments)}"
    )

    mucodec_decoder = build_mucodec_decoder(args)

    print("[INFO] running segment-level autoregressive generation...")
    t1 = time.time()
    (
        generated_ids,
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    ) = generate_segmentwise(
        model=model,
        sample=sample,
        layout=layout,
        device=device,
        use_cache=use_cache,
        temperature=float(args.temperature),
        top_k=int(args.top_k),
        top_p=float(args.top_p),
        greedy=bool(args.greedy),
        max_audio_tokens=max(0, int(args.max_audio_tokens)),
    )
    print(f"[INFO] sampled audio tokens: {sampled_count}")
    print(f"[INFO] output sequence length: {generated_ids.numel()}")
    t2 = time.time()
    print("total time:", t2 - t1)

    save_outputs(
        output_dir=args.output_dir,
        output_prefix=args.output_prefix,
        sample=sample,
        layout=layout,
        generated_ids=generated_ids,
        sampled_chord_ids=sampled_chord_ids,
        sampled_segment_ids=sampled_segment_ids,
        args=args,
        mucodec_decoder=mucodec_decoder,
    )


if __name__ == "__main__":
    main()