ChuxiJ committed on
Commit
0659e3b
·
1 Parent(s): a161649

support test time scaling & auto score & next batch

Browse files
acestep/gradio_ui.py CHANGED
The diff for this file is too large to render. See raw diff
 
acestep/llm_inference.py CHANGED
@@ -337,6 +337,155 @@ class LLMHandler:
337
  output_text = str(outputs)
338
 
339
  return output_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
  def _run_pt_from_formatted(
342
  self,
@@ -573,6 +722,8 @@ class LLMHandler:
573
  use_cot_caption: Whether to generate caption in CoT (default True).
574
  use_cot_language: Whether to generate language in CoT (default True).
575
  """
 
 
576
  infer_type = (infer_type or "").strip().lower()
577
  if infer_type not in {"dit", "llm_dit"}:
578
  return {}, "", f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
@@ -581,10 +732,15 @@ class LLMHandler:
581
  audio_codes = ""
582
  has_all_metas = self.has_all_metas(user_metadata)
583
 
 
 
 
 
584
  # ========== PHASE 1: CoT Generation ==========
585
  # Always generate CoT unless all metadata are user-provided
586
  if not has_all_metas or not is_format_caption:
587
  logger.info("Phase 1: Generating CoT metadata...")
 
588
 
589
  # Build formatted prompt for CoT phase
590
  formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
@@ -615,12 +771,14 @@ class LLMHandler:
615
  stop_at_reasoning=True, # Always stop at </think> in Phase 1
616
  )
617
 
 
 
618
  if not cot_output_text:
619
  return {}, "", status
620
 
621
  # Parse metadata from CoT output
622
  metadata, _ = self.parse_lm_output(cot_output_text)
623
- logger.info(f"Phase 1 completed. Generated metadata: {list(metadata.keys())}")
624
  else:
625
  # Use user-provided metadata
626
  logger.info("Phase 1: Using user-provided metadata (skipping generation)")
@@ -628,11 +786,12 @@ class LLMHandler:
628
 
629
  # If infer_type is 'dit', stop here and return only metadata
630
  if infer_type == "dit":
631
- status_msg = f"✅ Generated CoT metadata successfully\nFields: {', '.join(metadata.keys())}"
632
  return metadata, "", status_msg
633
 
634
  # ========== PHASE 2: Audio Codes Generation ==========
635
  logger.info("Phase 2: Generating audio codes...")
 
636
 
637
  # Format metadata as CoT using YAML (matching training format)
638
  cot_text = self._format_metadata_as_cot(metadata)
@@ -668,14 +827,192 @@ class LLMHandler:
668
  if not codes_output_text:
669
  return metadata, "", status
670
 
 
 
671
  # Parse audio codes from output (metadata should be same as Phase 1)
672
  _, audio_codes = self.parse_lm_output(codes_output_text)
673
 
674
  codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
675
- logger.info(f"Phase 2 completed. Generated {codes_count} audio codes")
676
 
677
- status_msg = f"✅ Generated successfully (2-phase)\nPhase 1: CoT metadata\nPhase 2: {codes_count} audio codes"
678
  return metadata, audio_codes, status_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
679
 
680
  def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
681
  """
 
337
  output_text = str(outputs)
338
 
339
  return output_text
340
+
341
    def _run_vllm_batch(
        self,
        formatted_prompts: List[str],
        temperature: float,
        cfg_scale: float,
        negative_prompt: str,
        top_k: Optional[int],
        top_p: Optional[float],
        repetition_penalty: float,
        use_constrained_decoding: bool = True,
        constrained_decoding_debug: bool = False,
        target_duration: Optional[float] = None,
        generation_phase: str = "codes",
        caption: str = "",
        lyrics: str = "",
        cot_text: str = "",
        seeds: Optional[List[int]] = None,
    ) -> List[str]:
        """Batch generation using vllm backend.

        All prompts share a single SamplingParams object and, when enabled,
        the single shared constrained-decoding processor instance.

        Args:
            formatted_prompts: Fully formatted prompts, one per batch item.
            temperature: Sampling temperature passed to the sampler.
            cfg_scale: Classifier-free-guidance scale; values > 1.0 enable CFG.
            negative_prompt: Used to build the unconditional prompt when CFG is active.
            top_k: Optional top-k sampling cutoff.
            top_p: Optional nucleus-sampling cutoff.
            repetition_penalty: Repetition penalty forwarded to the sampler.
            use_constrained_decoding: Apply the shared constrained processor.
            constrained_decoding_debug: Enable processor debug output.
            target_duration: Target audio duration hint for the processor.
            generation_phase: Phase tag for the processor (default "codes").
            caption: Caption used to build the unconditional (CFG) prompt.
            lyrics: Lyrics used to build the unconditional (CFG) prompt.
            cot_text: CoT text used to build the unconditional (CFG) prompt.
            seeds: Accepted for interface parity with _run_pt_batch.
                NOTE(review): this path currently ignores seeds — confirm intended.

        Returns:
            List of generated texts, one per input prompt.
        """
        from nanovllm import SamplingParams

        batch_size = len(formatted_prompts)

        # Determine effective temperature for sampler
        effective_sampler_temp = temperature

        # Use shared constrained processor if enabled
        # Note: vllm batch mode uses same processor for all items
        constrained_processor = None
        if use_constrained_decoding:
            # Reset processor state for new generation
            self.constrained_processor.reset()

            self.constrained_processor.enabled = use_constrained_decoding
            self.constrained_processor.debug = constrained_decoding_debug
            self.constrained_processor.metadata_temperature = None
            self.constrained_processor.codes_temperature = None
            self.constrained_processor.set_target_duration(target_duration)
            self.constrained_processor.set_user_metadata(None)
            self.constrained_processor.set_stop_at_reasoning(False)
            self.constrained_processor.set_skip_genres(True)
            self.constrained_processor.set_skip_caption(True)
            self.constrained_processor.set_skip_language(True)
            self.constrained_processor.set_generation_phase(generation_phase)

            constrained_processor = self.constrained_processor

        # Build sampling params
        sampling_params = SamplingParams(
            max_tokens=self.max_model_len - 64,
            temperature=effective_sampler_temp,
            cfg_scale=cfg_scale,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            logits_processor=constrained_processor,
            logits_processor_update_state=constrained_processor.update_state if constrained_processor else None,
        )

        # Generate with or without CFG
        if cfg_scale > 1.0:
            # Build unconditional prompts
            formatted_unconditional_prompt = self.build_formatted_prompt_with_cot(
                caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
            )
            unconditional_prompts = [formatted_unconditional_prompt] * batch_size

            outputs = self.llm.generate(
                formatted_prompts,
                sampling_params,
                unconditional_prompts=unconditional_prompts,
            )
        else:
            outputs = self.llm.generate(formatted_prompts, sampling_params)

        # Extract text from each output; the duck-typing ladder tolerates the
        # different result shapes the backend may return.
        output_texts = []
        for output in outputs:
            if hasattr(output, "outputs") and len(output.outputs) > 0:
                output_texts.append(output.outputs[0].text)
            elif hasattr(output, "text"):
                output_texts.append(output.text)
            elif isinstance(output, dict) and "text" in output:
                output_texts.append(output["text"])
            else:
                output_texts.append(str(output))

        return output_texts
429
+
430
+ def _run_pt_batch(
431
+ self,
432
+ formatted_prompts: List[str],
433
+ temperature: float,
434
+ cfg_scale: float,
435
+ negative_prompt: str,
436
+ top_k: Optional[int],
437
+ top_p: Optional[float],
438
+ repetition_penalty: float,
439
+ use_constrained_decoding: bool = True,
440
+ constrained_decoding_debug: bool = False,
441
+ target_duration: Optional[float] = None,
442
+ generation_phase: str = "codes",
443
+ caption: str = "",
444
+ lyrics: str = "",
445
+ cot_text: str = "",
446
+ seeds: Optional[List[int]] = None,
447
+ ) -> List[str]:
448
+ """Batch generation using PyTorch backend"""
449
+ import random
450
+
451
+ batch_size = len(formatted_prompts)
452
+ output_texts = []
453
+
454
+ # Generate each item sequentially with different seeds
455
+ # (PyTorch backend doesn't support true batching efficiently)
456
+ for i, formatted_prompt in enumerate(formatted_prompts):
457
+ # Set seed for this item if provided
458
+ if seeds and i < len(seeds):
459
+ torch.manual_seed(seeds[i])
460
+ if torch.cuda.is_available():
461
+ torch.cuda.manual_seed_all(seeds[i])
462
+
463
+ # Generate using single-item method
464
+ output_text = self._run_pt_from_formatted(
465
+ formatted_prompt=formatted_prompt,
466
+ temperature=temperature,
467
+ cfg_scale=cfg_scale,
468
+ negative_prompt=negative_prompt,
469
+ top_k=top_k,
470
+ top_p=top_p,
471
+ repetition_penalty=repetition_penalty,
472
+ use_constrained_decoding=use_constrained_decoding,
473
+ constrained_decoding_debug=constrained_decoding_debug,
474
+ target_duration=target_duration,
475
+ user_metadata=None,
476
+ stop_at_reasoning=False,
477
+ skip_genres=True,
478
+ skip_caption=True,
479
+ skip_language=True,
480
+ generation_phase=generation_phase,
481
+ caption=caption,
482
+ lyrics=lyrics,
483
+ cot_text=cot_text,
484
+ )
485
+
486
+ output_texts.append(output_text)
487
+
488
+ return output_texts
489
 
490
  def _run_pt_from_formatted(
491
  self,
 
722
  use_cot_caption: Whether to generate caption in CoT (default True).
723
  use_cot_language: Whether to generate language in CoT (default True).
724
  """
725
+ import time
726
+
727
  infer_type = (infer_type or "").strip().lower()
728
  if infer_type not in {"dit", "llm_dit"}:
729
  return {}, "", f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
 
732
  audio_codes = ""
733
  has_all_metas = self.has_all_metas(user_metadata)
734
 
735
+ # Timing variables
736
+ phase1_time = 0.0
737
+ phase2_time = 0.0
738
+
739
  # ========== PHASE 1: CoT Generation ==========
740
  # Always generate CoT unless all metadata are user-provided
741
  if not has_all_metas or not is_format_caption:
742
  logger.info("Phase 1: Generating CoT metadata...")
743
+ phase1_start = time.time()
744
 
745
  # Build formatted prompt for CoT phase
746
  formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
 
771
  stop_at_reasoning=True, # Always stop at </think> in Phase 1
772
  )
773
 
774
+ phase1_time = time.time() - phase1_start
775
+
776
  if not cot_output_text:
777
  return {}, "", status
778
 
779
  # Parse metadata from CoT output
780
  metadata, _ = self.parse_lm_output(cot_output_text)
781
+ logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
782
  else:
783
  # Use user-provided metadata
784
  logger.info("Phase 1: Using user-provided metadata (skipping generation)")
 
786
 
787
  # If infer_type is 'dit', stop here and return only metadata
788
  if infer_type == "dit":
789
+ status_msg = f"✅ Generated CoT metadata successfully\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
790
  return metadata, "", status_msg
791
 
792
  # ========== PHASE 2: Audio Codes Generation ==========
793
  logger.info("Phase 2: Generating audio codes...")
794
+ phase2_start = time.time()
795
 
796
  # Format metadata as CoT using YAML (matching training format)
797
  cot_text = self._format_metadata_as_cot(metadata)
 
827
  if not codes_output_text:
828
  return metadata, "", status
829
 
830
+ phase2_time = time.time() - phase2_start
831
+
832
  # Parse audio codes from output (metadata should be same as Phase 1)
833
  _, audio_codes = self.parse_lm_output(codes_output_text)
834
 
835
  codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
836
+ logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
837
 
838
+ status_msg = f"✅ Generated successfully (2-phase)\nPhase 1: CoT metadata\nPhase 2: {codes_count} audio codes\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
839
  return metadata, audio_codes, status_msg
840
+
841
+ def generate_with_stop_condition_batch(
842
+ self,
843
+ caption: str,
844
+ lyrics: str,
845
+ batch_size: int,
846
+ infer_type: str = "llm_dit",
847
+ temperature: float = 0.85,
848
+ cfg_scale: float = 1.0,
849
+ negative_prompt: str = "NO USER INPUT",
850
+ top_k: Optional[int] = None,
851
+ top_p: Optional[float] = None,
852
+ repetition_penalty: float = 1.0,
853
+ use_constrained_decoding: bool = True,
854
+ constrained_decoding_debug: bool = False,
855
+ target_duration: Optional[float] = None,
856
+ user_metadata: Optional[Dict[str, Optional[str]]] = None,
857
+ use_cot_caption: bool = True,
858
+ use_cot_language: bool = True,
859
+ is_format_caption: bool = False,
860
+ seeds: Optional[List[int]] = None,
861
+ ) -> Tuple[List[Dict[str, Any]], List[str], str]:
862
+ """
863
+ Batch version of generate_with_stop_condition.
864
+
865
+ Generates multiple audio codes with same conditions but different seeds (for diversity).
866
+
867
+ Args:
868
+ caption: Same caption for all items
869
+ lyrics: Same lyrics for all items
870
+ batch_size: Number of items to generate
871
+ seeds: Optional list of seeds for each batch item (for reproducibility)
872
+ ... (other args same as generate_with_stop_condition)
873
+
874
+ Returns:
875
+ Tuple of (metadata_list, audio_codes_list, status_message)
876
+ - metadata_list: List of metadata dicts (same metadata for all items)
877
+ - audio_codes_list: List of audio code strings (one per item, different due to sampling)
878
+ - status_message: Generation status
879
+ """
880
+ import random
881
+ import time
882
+
883
+ infer_type = (infer_type or "").strip().lower()
884
+ if infer_type not in {"dit", "llm_dit"}:
885
+ return [], [], f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
886
+
887
+ # Generate seeds if not provided
888
+ if seeds is None:
889
+ seeds = [random.randint(0, 2**32 - 1) for _ in range(batch_size)]
890
+ elif len(seeds) < batch_size:
891
+ # Pad with random seeds if not enough provided
892
+ seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(batch_size - len(seeds))]
893
+ else:
894
+ seeds = seeds[:batch_size] # Truncate if too many
895
+
896
+ # Timing variables
897
+ phase1_time = 0.0
898
+ phase2_time = 0.0
899
+
900
+ # ========== PHASE 1: CoT Generation (ONCE for all items) ==========
901
+ has_all_metas = self.has_all_metas(user_metadata)
902
+
903
+ if not has_all_metas or not is_format_caption:
904
+ logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
905
+ phase1_start = time.time()
906
+
907
+ # Generate CoT metadata once (same for all batch items)
908
+ metadata, _, status = self.generate_with_stop_condition(
909
+ caption=caption,
910
+ lyrics=lyrics,
911
+ infer_type="dit", # Only generate metadata
912
+ temperature=temperature,
913
+ cfg_scale=cfg_scale,
914
+ negative_prompt=negative_prompt,
915
+ top_k=top_k,
916
+ top_p=top_p,
917
+ repetition_penalty=repetition_penalty,
918
+ use_constrained_decoding=use_constrained_decoding,
919
+ constrained_decoding_debug=constrained_decoding_debug,
920
+ target_duration=target_duration,
921
+ user_metadata=user_metadata,
922
+ use_cot_caption=use_cot_caption,
923
+ use_cot_language=use_cot_language,
924
+ is_format_caption=is_format_caption,
925
+ )
926
+
927
+ phase1_time = time.time() - phase1_start
928
+
929
+ if not metadata:
930
+ return [], [], status
931
+
932
+ logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
933
+ else:
934
+ # Use user-provided metadata
935
+ logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)")
936
+ metadata = {k: v for k, v in user_metadata.items() if v is not None}
937
+
938
+ # If infer_type is 'dit', stop here and return only metadata
939
+ if infer_type == "dit":
940
+ metadata_list = [metadata.copy() for _ in range(batch_size)]
941
+ status_msg = f"✅ Generated CoT metadata successfully (batch mode)\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
942
+ return metadata_list, [""] * batch_size, status_msg
943
+
944
+ # ========== PHASE 2: Audio Codes Generation (BATCH) ==========
945
+ logger.info(f"Batch Phase 2: Generating audio codes for {batch_size} items...")
946
+ phase2_start = time.time()
947
+
948
+ # Format metadata as CoT
949
+ cot_text = self._format_metadata_as_cot(metadata)
950
+
951
+ # Build formatted prompt with CoT
952
+ formatted_prompt = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
953
+
954
+ # Replicate prompt for batch (all items have same prompt, differ by seeds)
955
+ formatted_prompts = [formatted_prompt] * batch_size
956
+
957
+ # Call backend-specific batch generation
958
+ try:
959
+ if self.llm_backend == "vllm":
960
+ codes_outputs = self._run_vllm_batch(
961
+ formatted_prompts=formatted_prompts,
962
+ temperature=temperature,
963
+ cfg_scale=cfg_scale,
964
+ negative_prompt=negative_prompt,
965
+ top_k=top_k,
966
+ top_p=top_p,
967
+ repetition_penalty=repetition_penalty,
968
+ use_constrained_decoding=use_constrained_decoding,
969
+ constrained_decoding_debug=constrained_decoding_debug,
970
+ target_duration=target_duration,
971
+ generation_phase="codes",
972
+ caption=caption,
973
+ lyrics=lyrics,
974
+ cot_text=cot_text,
975
+ seeds=seeds,
976
+ )
977
+ else: # pt backend
978
+ codes_outputs = self._run_pt_batch(
979
+ formatted_prompts=formatted_prompts,
980
+ temperature=temperature,
981
+ cfg_scale=cfg_scale,
982
+ negative_prompt=negative_prompt,
983
+ top_k=top_k,
984
+ top_p=top_p,
985
+ repetition_penalty=repetition_penalty,
986
+ use_constrained_decoding=use_constrained_decoding,
987
+ constrained_decoding_debug=constrained_decoding_debug,
988
+ target_duration=target_duration,
989
+ generation_phase="codes",
990
+ caption=caption,
991
+ lyrics=lyrics,
992
+ cot_text=cot_text,
993
+ seeds=seeds,
994
+ )
995
+ except Exception as e:
996
+ error_msg = f"❌ Error in batch codes generation: {str(e)}"
997
+ logger.error(error_msg)
998
+ return [], [], error_msg
999
+
1000
+ # Parse audio codes from each output
1001
+ audio_codes_list = []
1002
+ metadata_list = []
1003
+ for output_text in codes_outputs:
1004
+ _, audio_codes = self.parse_lm_output(output_text)
1005
+ audio_codes_list.append(audio_codes)
1006
+ metadata_list.append(metadata.copy()) # Same metadata for all
1007
+
1008
+ phase2_time = time.time() - phase2_start
1009
+
1010
+ # Log results
1011
+ codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
1012
+ logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
1013
+
1014
+ status_msg = f"✅ Batch generation completed ({batch_size} items)\nPhase 1: CoT metadata\nPhase 2: {sum(codes_counts)} total codes ({codes_counts})\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
1015
+ return metadata_list, audio_codes_list, status_msg
1016
 
1017
  def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
1018
  """
acestep/test_time_scaling.py CHANGED
@@ -4,258 +4,316 @@ Implements perplexity-based scoring for generated audio codes
4
  """
5
  import torch
6
  import torch.nn.functional as F
7
- from typing import Tuple, Optional, Dict, Any
8
  from loguru import logger
9
  import yaml
 
 
10
 
11
 
12
- def perplexity_to_score(perplexity: float, scale: float = 100.0) -> float:
13
  """
14
- Convert perplexity to a normalized score in [0, 1] range.
15
 
16
- Lower perplexity = higher score (better quality)
17
- Uses exponential decay: score = exp(-perplexity / scale)
 
 
 
18
 
19
  Args:
20
- perplexity: Perplexity value (typically 1 to 1000+)
21
- scale: Scale parameter to control score distribution (default 100.0)
22
- - Smaller scale: more sensitive to perplexity changes
23
- - Larger scale: less sensitive to perplexity changes
24
 
25
  Returns:
26
- Score in [0, 1] range, where 1 is perfect and 0 is worst
27
-
28
- Examples:
29
- perplexity=1 score≈0.99 (excellent)
30
- perplexity=50 → score≈0.61 (good if scale=100)
31
- perplexity=100 → score≈0.37 (medium if scale=100)
32
- perplexity=500 → score≈0.01 (poor if scale=100)
33
  """
34
- import math
35
- return math.exp(-perplexity / scale)
36
 
37
 
38
- def calculate_perplexity(
39
- llm_handler,
40
- audio_codes: str,
41
- caption: str = "",
42
- lyrics: str = "",
43
- metadata: Optional[Dict[str, Any]] = None,
44
- temperature: float = 1.0,
45
- ) -> Tuple[float, str]:
46
  """
47
- Calculate perplexity of generated audio codes conditioned on caption/lyrics/metadata.
48
-
49
- This reverses the generation task: given audio codes as input, measure how well
50
- the model can predict the CoT metadata and lyrics that should generate those codes.
51
 
52
- Lower perplexity = model is less surprised = better quality generation
53
- Score = -perplexity (higher is better)
54
-
55
- The understanding task format is:
56
- Input: <|audio_code_123|><|audio_code_456|>...
57
- Output: <think>\nmetadata_yaml\n</think>\n\n# Lyric\nlyrics_text
58
 
59
  Args:
60
- llm_handler: LLM handler instance with initialized model
61
- audio_codes: Generated audio code string (e.g., "<|audio_code_123|><|audio_code_456|>...")
62
- caption: Caption text used for generation
63
- lyrics: Lyrics text used for generation
64
- metadata: Dictionary with CoT metadata fields (bpm, duration, keyscale, language, timesignature, etc.)
65
- temperature: Temperature for probability scaling (default 1.0)
66
 
67
  Returns:
68
- Tuple of (perplexity_value, status_message)
 
 
 
69
 
70
- Example:
71
- metadata = {'bpm': 120, 'duration': 30, 'keyscale': 'C major', 'language': 'en', 'timesignature': '4'}
72
- perplexity, status = calculate_perplexity(
73
- llm_handler,
74
- audio_codes="<|audio_code_123|>...",
75
- caption="calm piano",
76
- lyrics="verse 1...",
77
- metadata=metadata
78
- )
79
- score = -perplexity # Higher score = better quality
80
  """
81
- if not llm_handler.llm_initialized:
82
- return float('inf'), "❌ LLM not initialized"
83
-
84
- if not audio_codes or not audio_codes.strip():
85
- return float('inf'), "❌ No audio codes provided"
86
-
87
- try:
88
- # Build the understanding prompt: codes as input
89
- # The model should generate: <think>metadata</think>\n# Lyric\n...
90
- formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(
91
- audio_codes=audio_codes,
92
- is_negative_prompt=False
93
- )
94
-
95
- logger.info(f"Calculating perplexity for {len(audio_codes)} character audio codes")
96
-
97
- # Build the expected output (target sequence) following understanding task format
98
- # Format: <think>\nmetadata_yaml\n</think>\n\n# Lyric\nlyrics_text
99
- target_parts = []
100
-
101
- # Build CoT section with metadata
102
- if metadata and isinstance(metadata, dict):
103
- # Filter out None values and format as YAML (sorted keys)
104
- cot_items = {}
105
- for key in ['bpm', 'caption', 'duration', 'genres', 'keyscale', 'language', 'timesignature']:
106
- if key in metadata and metadata[key] is not None:
107
- cot_items[key] = metadata[key]
108
-
109
- if cot_items:
110
- cot_yaml = yaml.dump(cot_items, allow_unicode=True, sort_keys=True).strip()
111
- target_parts.append(f"<think>\n{cot_yaml}\n</think>\n")
112
-
113
- # Add Lyric section (note: understanding task uses "# Lyric" not "# Caption")
114
- if lyrics:
115
- target_parts.append(f"\n# Lyric\n{lyrics}\n")
116
-
117
- target_text = "".join(target_parts)
118
-
119
- if not target_text.strip():
120
- return float('inf'), "❌ No target text to evaluate (lyrics or metadata required)"
121
-
122
- logger.debug(f"Target text (first 200 chars): {target_text[:200]}...")
123
-
124
- # Calculate perplexity using appropriate backend
125
- if llm_handler.llm_backend == "vllm":
126
- perplexity = _calculate_perplexity_vllm(
127
- llm_handler,
128
- formatted_prompt,
129
- target_text,
130
- temperature
131
- )
132
- else: # pt backend
133
- perplexity = _calculate_perplexity_pt(
134
- llm_handler,
135
- formatted_prompt,
136
- target_text,
137
- temperature
138
- )
139
-
140
- status_msg = f"✅ Perplexity calculated: {perplexity:.4f}"
141
- logger.info(status_msg)
142
- return perplexity, status_msg
143
-
144
- except Exception as e:
145
- error_msg = f"❌ Error calculating perplexity: {str(e)}"
146
- logger.error(error_msg)
147
- import traceback
148
- logger.error(traceback.format_exc())
149
- return float('inf'), error_msg
150
 
151
 
152
- def _calculate_perplexity_pt(
153
- llm_handler,
154
- formatted_prompt: str,
155
- target_text: str,
156
- temperature: float
157
- ) -> float:
158
  """
159
- Calculate perplexity using PyTorch backend.
160
-
161
- For vllm backend, this uses a shared-weight HuggingFace model.
162
- For pt backend, this uses the original model.
163
-
164
  Args:
165
- llm_handler: LLM handler with pt or vllm backend
166
- formatted_prompt: Formatted input prompt (audio codes)
167
- target_text: Expected output text (CoT metadata + lyrics)
168
- temperature: Temperature for probability scaling
169
 
170
  Returns:
171
- Perplexity value
 
 
172
  """
173
- # Get model for scoring (handles both pt and vllm backends)
174
  model = llm_handler.get_hf_model_for_scoring()
175
  tokenizer = llm_handler.llm_tokenizer
176
  device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device
177
-
178
- # Tokenize prompt and target separately
179
- prompt_tokens = tokenizer(
180
- formatted_prompt,
181
- return_tensors="pt",
182
- padding=False,
183
- truncation=True,
184
- )
185
-
186
- target_tokens = tokenizer(
187
- target_text,
188
- return_tensors="pt",
189
- padding=False,
190
- truncation=True,
191
- )
192
-
193
- # Concatenate prompt + target for full sequence
194
- full_input_ids = torch.cat([
195
- prompt_tokens['input_ids'],
196
- target_tokens['input_ids']
197
- ], dim=1).to(device)
198
-
199
- # Create attention mask
200
- attention_mask = torch.ones_like(full_input_ids)
201
-
202
- # Forward pass to get logits
203
  with torch.no_grad():
204
  with llm_handler._load_model_context():
205
- outputs = model(
206
- input_ids=full_input_ids,
207
- attention_mask=attention_mask
208
- )
209
- logits = outputs.logits # [batch_size, seq_len, vocab_size]
210
-
211
- # Get the logits for predicting target tokens
212
- # Shift logits and labels: logits[i] predicts token[i+1]
213
- prompt_len = prompt_tokens['input_ids'].shape[1]
214
- target_len = target_tokens['input_ids'].shape[1]
215
-
216
- # Extract logits for positions that predict target tokens
217
- # logits at positions [prompt_len-1 : prompt_len+target_len-1] predict target tokens
218
- pred_logits = logits[0, prompt_len-1:prompt_len+target_len-1, :] # [target_len, vocab_size]
219
- target_ids = target_tokens['input_ids'][0] # [target_len]
220
-
221
- # Apply temperature scaling
222
- if temperature != 1.0:
223
- pred_logits = pred_logits / temperature
224
-
225
- # Calculate cross-entropy loss for each position
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  log_probs = F.log_softmax(pred_logits, dim=-1) # [target_len, vocab_size]
227
-
228
- # Gather log probabilities of target tokens
229
- target_log_probs = log_probs[torch.arange(target_len), target_ids] # [target_len]
230
-
231
- # Calculate perplexity: exp(-mean(log_probs))
232
- mean_neg_log_prob = -target_log_probs.mean()
233
- perplexity = torch.exp(mean_neg_log_prob).item()
234
-
235
- return perplexity
236
 
 
 
 
 
 
237
 
238
- def _calculate_perplexity_vllm(
 
 
 
 
 
 
 
 
239
  llm_handler,
240
- formatted_prompt: str,
241
- target_text: str,
242
- temperature: float
243
- ) -> float:
 
 
 
 
244
  """
245
- Calculate perplexity using vllm backend.
246
-
247
- Uses shared-weight HuggingFace model for perplexity calculation.
248
- This avoids the complexity of nanovllm's context management.
249
-
250
- Args:
251
- llm_handler: LLM handler with vllm backend
252
- formatted_prompt: Formatted input prompt (audio codes)
253
- target_text: Expected output text (CoT metadata + lyrics)
254
- temperature: Temperature for probability scaling
255
-
256
- Returns:
257
- Perplexity value
258
  """
259
- logger.debug("Using vllm backend with shared-weight HuggingFace model for perplexity")
260
- # Delegate to pt backend implementation which now handles both backends
261
- return _calculate_perplexity_pt(llm_handler, formatted_prompt, target_text, temperature)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
  import torch
6
  import torch.nn.functional as F
7
+ from typing import Tuple, Optional, Dict, Any, List
8
  from loguru import logger
9
  import yaml
10
+ import math
11
+ import re
12
 
13
 
14
def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
    """Return the Pointwise Mutual Information (PMI) between codes and condition.

    PMI = log P(condition|codes) - log P(condition)
        = log [P(codes|condition) / P(codes)]

    Subtracting the unconditional term removes the bias from P(condition),
    so the score measures how much the codes improve our ability to
    predict the condition.

    Args:
        log_prob_conditional: Average log probability of condition given codes.
        log_prob_unconditional: Average log probability of condition without codes.

    Returns:
        PMI score (higher is better, can be positive or negative):
        - Positive: codes improve prediction → good match
        - Zero: codes don't help → no correlation
        - Negative: codes hurt prediction → poor match
    """
    information_gain = log_prob_conditional - log_prob_unconditional
    return information_gain
 
35
 
36
 
37
def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
    """
    Convert PMI score to normalized [0, 1] range using sigmoid function.

    score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))

    Args:
        pmi: PMI score (can be positive or negative)
        scale: Scale parameter to control sensitivity (default 0.1)
            - Smaller scale: more sensitive to PMI changes
            - Larger scale: less sensitive to PMI changes

    Returns:
        Normalized score in [0, 1] range, where:
        - PMI > 0 → score > 0.5 (good match)
        - PMI = 0 → score = 0.5 (neutral)
        - PMI < 0 → score < 0.5 (poor match)

    Examples (scale=1.0):
        PMI=2.0 → score≈0.88 (excellent)
        PMI=1.0 → score≈0.73 (good)
        PMI=0.0 → score=0.50 (neutral)
        PMI=-1.0 → score≈0.27 (poor)
        PMI=-2.0 → score≈0.12 (bad)

    Note:
        Uses the numerically stable two-branch sigmoid. The naive
        1/(1+exp(-x)) raises OverflowError once -x exceeds ~709 (for the
        default scale=0.1 that is any PMI below about -71).
    """
    x = pmi / scale
    if x >= 0:
        # exp(-x) is in (0, 1]; no overflow possible on this branch.
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, exp(-x) can overflow; use the algebraically
    # equivalent form exp(x) / (1 + exp(x)) where exp(x) underflows to 0.
    ex = math.exp(x)
    return ex / (1.0 + ex)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
                                       target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run one teacher-forced forward pass and return the logits and labels
    needed to score `target_text` given the context `formatted_prompt`.

    Args:
        llm_handler: The handler containing the model and tokenizer.
        formatted_prompt: The input context.
        target_text: The text we want to calculate probability/recall for.

    Returns:
        Tuple of (target_logits, target_ids)
        - target_logits: Logits used to predict the target tokens.
        - target_ids: The ground truth token IDs of the target.
        Both are empty tensors when the target tokenizes to nothing or is
        entirely truncated away.
    """
    model = llm_handler.get_hf_model_for_scoring()
    tokenizer = llm_handler.llm_tokenizer
    # pt backend: model lives on the handler's device; otherwise read the
    # device from the scoring model's own parameters.
    device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device

    # 1. Tokenize prompt ONLY to get its length (used for slicing later).
    #    We must ensure special tokens are added to count the offset correctly.
    prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
    prompt_len = prompt_tokens_temp['input_ids'].shape[1]

    # 2. Tokenize the FULL text (Prompt + Target).
    #    This ensures subword merging at boundaries is handled correctly by the tokenizer.
    #    NOTE(review): assumes the prompt's tokens form an exact length-`prompt_len`
    #    prefix of the full tokenization — confirm boundary merging cannot shift
    #    the offset for this tokenizer.
    full_text = formatted_prompt + target_text
    full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)

    input_ids = full_tokens['input_ids']

    # Safety check: if target was empty or truncated entirely
    if input_ids.shape[1] <= prompt_len:
        return torch.empty(0, device=device), torch.empty(0, device=device)

    # 3. Forward Pass (Teacher Forcing)
    with torch.no_grad():
        with llm_handler._load_model_context():
            outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
            all_logits = outputs.logits  # [1, seq_len, vocab_size]

    # 4. Extract Logits and Labels
    # We need to predict `input_ids[i]`. The logit for this is at `all_logits[i-1]`.
    # Target starts at index `prompt_len`.
    # So we need logits from `prompt_len - 1` up to the second to last position.

    target_logits = all_logits[0, prompt_len - 1:-1, :]  # [target_len, vocab_size]
    target_ids = input_ids[0, prompt_len:]  # [target_len]

    return target_logits, target_ids
113
+
114
+
115
+ # ==============================================================================
116
+ # Scoring Logic
117
+ # ==============================================================================
118
+
119
+
120
def _calculate_topk_recall(llm_handler,
                           formatted_prompt: str,
                           target_text: str,
                           topk: int = 10) -> Tuple[float, Dict[int, float]]:
    """
    Calculate top-k recall for target text given prompt.

    Checks, at every target position, whether the ground-truth token is within
    the model's top-k predictions.

    Args:
        llm_handler: The handler containing the model and tokenizer.
        formatted_prompt: Conditioning context.
        target_text: Text whose tokens are scored against the predictions.
        topk: Largest k for the recall curve.

    Returns:
        Tuple of (average_recall, recall_per_k):
        - average_recall: mean rank-weighted hit score in [0, 1]
          (rank 1 -> 1.0, rank topk -> 1/topk, miss -> 0.0).
        - recall_per_k: {k: fraction of positions with ground truth in top-k}
          for k = 1..topk. Empty dict when there is nothing to score.
    """
    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)

    if target_ids.shape[0] == 0:
        return 0.0, {}

    target_len = target_ids.shape[0]

    # Get top-k indices for all positions at once: [target_len, effective_k]
    effective_k = min(topk, pred_logits.shape[-1])
    _, topk_indices = torch.topk(pred_logits, k=effective_k, dim=-1)

    # Compute the 1-based rank of the ground-truth token at each position ONCE
    # (None on miss). This replaces the original per-k rescans of the top-k
    # slice, which were O(target_len * topk^2), with a single O(target_len * k)
    # pass; all returned values are identical.
    ranks = []
    for row, gt_token in zip(topk_indices.tolist(), target_ids.tolist()):
        ranks.append(row.index(gt_token) + 1 if gt_token in row else None)

    # Recall curve: ground truth is in the top-k slice iff its rank is <= k
    # (the top-k indices at a position are distinct, so rank is well-defined).
    recall_per_k = {
        k: sum(1 for r in ranks if r is not None and r <= k) / target_len
        for k in range(1, topk + 1)
    }

    # Position-weighted average: rank 1 = 1.0, rank topk = small positive,
    # misses contribute 0.0 (same as the original zero-padding).
    weighted_hits = sum(1.0 - (r - 1) / topk for r in ranks if r is not None)
    average_recall = weighted_hits / target_len

    return average_recall, recall_per_k
172
+
173
+
174
def _calculate_metadata_recall(llm_handler,
                               formatted_prompt: str,
                               fields_dict: Dict[str, Any],
                               topk: int = 10) -> Dict[str, float]:
    """
    Score each metadata field independently via top-k recall.

    Args:
        llm_handler: The handler containing the model and tokenizer.
        formatted_prompt: Conditioning context.
        fields_dict: Dictionary of {field_name: field_value}.
        topk: k passed through to the recall computation.

    Returns:
        {field_name: average recall score}, fields processed in sorted key order.
    """
    if not fields_dict:
        return {}

    results: Dict[str, float] = {}
    for name in sorted(fields_dict):
        # Render this single field as YAML wrapped in think-tags,
        # e.g. <think>\nbpm: 120\n</think>\n
        rendered = yaml.dump({name: fields_dict[name]}, allow_unicode=True, sort_keys=True).strip()
        target = f"<think>\n{rendered}\n</think>\n"

        # Robust recall scoring for this field alone.
        score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, target, topk=topk)

        results[name] = score
        logger.debug(f"Recall for {name}: {score:.4f}")

    return results
200
+
201
+
202
def _calculate_log_prob(
    llm_handler,
    formatted_prompt: str,
    target_text: str,
    temperature: float = 1.0  # Kept for API compatibility, but ignored for scoring
) -> float:
    """
    Calculate the average log probability of target text given prompt.

    Note: `temperature` is accepted for API compatibility but deliberately NOT
    applied — log-probabilities used for PMI/perplexity must be exact.

    Returns:
        Mean per-token log probability, or -inf when the target is empty
        or entirely truncated.
    """
    logits, labels = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)

    n_tokens = labels.shape[0]
    if n_tokens == 0:
        return float('-inf')

    # Exact (untempered) log-probabilities over the vocabulary.
    log_probs = F.log_softmax(logits, dim=-1)  # [target_len, vocab_size]

    # Pick out the log-prob the model assigned to each ground-truth token,
    # then average over the target length.
    per_token = log_probs[torch.arange(n_tokens), labels]
    return per_token.mean().item()
229
+
230
+
231
+ # ==============================================================================
232
+ # Main Public API
233
+ # ==============================================================================
234
+
235
+
236
def calculate_pmi_score_per_condition(
    llm_handler,
    audio_codes: str,
    caption: str = "",
    lyrics: str = "",
    metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 1.0,
    topk: int = 10,
    score_scale: float = 0.1,
) -> Tuple[Dict[str, float], float, str]:
    """
    Calculate a quality score separately for each condition.

    - Metadata fields (bpm, duration, genres, ...): Top-k Recall.
    - Caption / Lyrics: PMI (normalized to [0, 1]).

    Args:
        llm_handler: The handler containing the model and tokenizer.
        audio_codes: Serialized audio codes used as the conditioning context.
        caption: Caption text; merged into metadata under 'caption' when absent.
        lyrics: Lyrics text; PMI-scored when non-empty.
        metadata: Optional metadata dict; None is treated as empty (the original
            crashed with TypeError on the None default).
        temperature: Accepted for API compatibility; not used in scoring.
        topk: k used for metadata recall.
        score_scale: Sigmoid scale for PMI normalization.

    Returns:
        Tuple of (per-condition scores, global mean score, status message).
        On internal error the score is -inf and the message carries the error.
    """
    if not llm_handler.llm_initialized:
        return {}, 0.0, "❌ LLM not initialized"

    if not audio_codes or not audio_codes.strip():
        return {}, 0.0, "❌ No audio codes provided"

    # Work on a shallow copy so the caller's dict is never mutated; this also
    # fixes the crash when metadata is None (the declared default).
    metadata = dict(metadata) if metadata else {}
    if "caption" not in metadata:
        metadata['caption'] = caption

    formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
    prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)

    # Which metric applies to which condition. Bound BEFORE the try block so
    # the status-message loop below can never hit a NameError.
    metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
    metadata_pmi_keys = ['caption']
    scores: Dict[str, float] = {}

    try:
        # 1. Calculate Recall for Metadata Fields
        for key in metadata_recall_keys:
            if key in metadata and metadata[key] is not None:
                field_scores = _calculate_metadata_recall(
                    llm_handler, formatted_prompt, {key: metadata[key]}, topk=topk
                )
                scores.update(field_scores)

        # 2. Calculate PMI for Caption
        for key in metadata_pmi_keys:
            if key in metadata and metadata[key] is not None:
                cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
                target_text = f"<think>\n{cot_yaml}\n</think>\n"

                log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
                log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)

                scores[key] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)

        # 3. Calculate PMI for Lyrics
        if lyrics:
            target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"

            log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
            # Reuse prompt_uncond built above (the original rebuilt it here redundantly).
            log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)

            scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)

        if not scores:
            return {}, 0.0, "❌ No conditions to evaluate"

        # 4. Global Score: unweighted mean over all evaluated conditions
        global_score = sum(scores.values()) / len(scores)

        # Status Message
        status_lines = ["✅ Per-condition scores (0-1):"]
        for key, score in sorted(scores.items()):
            metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
            status_lines.append(f" {key}: {score:.4f} ({metric})")
        status_lines.append(f"Global score: {global_score:.4f}")

        logger.info(f"Calculated scores: {global_score:.4f}")
        return scores, global_score, "\n".join(status_lines)

    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return {}, float('-inf'), error_msg