|
|
""" |
|
|
Test-Time Scaling Module |
|
|
Implements perplexity-based scoring for generated audio codes |
|
|
""" |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from typing import Tuple, Optional, Dict, Any, List |
|
|
from loguru import logger |
|
|
import yaml |
|
|
import math |
|
|
import re |
|
|
|
|
|
|
|
|
def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
    """
    Pointwise Mutual Information (PMI) between audio codes and a condition.

    PMI = log P(condition|codes) - log P(condition)
        = log [ P(codes|condition) / P(codes) ]

    Subtracting the unconditional term cancels the bias from P(condition),
    leaving a measure of how much the codes help predict the condition.

    Args:
        log_prob_conditional: Average log probability of the condition
            given the codes.
        log_prob_unconditional: Average log probability of the condition
            without the codes.

    Returns:
        PMI score (unbounded, higher is better):
        - Positive: codes improve prediction -> good match
        - Zero: codes don't help -> no correlation
        - Negative: codes hurt prediction -> poor match
    """
    pmi = log_prob_conditional - log_prob_unconditional
    return pmi
|
|
|
|
|
|
|
|
def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
    """
    Convert a PMI score to the [0, 1] range with a sigmoid.

    score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))

    Fix: the naive ``1 / (1 + math.exp(-pmi / scale))`` raises OverflowError
    once ``-pmi / scale`` exceeds ~709 (e.g. pmi = -100 at the default
    scale 0.1), which is easily reached by log-probability differences.
    This version uses the numerically stable two-branch sigmoid so the
    exponent passed to ``math.exp`` is always <= 0, saturating cleanly to
    0.0 / 1.0 instead of crashing.

    Args:
        pmi: PMI score (can be positive or negative).
        scale: Sensitivity control (default 0.1); smaller values make the
            score react more sharply to PMI changes.

    Returns:
        Normalized score in [0, 1]:
        - PMI > 0 -> score > 0.5 (good match)
        - PMI = 0 -> score = 0.5 (neutral)
        - PMI < 0 -> score < 0.5 (poor match)

    Examples (scale=1.0):
        PMI=2.0  -> score~0.88 (excellent)
        PMI=1.0  -> score~0.73 (good)
        PMI=0.0  -> score=0.50 (neutral)
        PMI=-1.0 -> score~0.27 (poor)
        PMI=-2.0 -> score~0.12 (bad)
    """
    z = pmi / scale
    if z >= 0:
        # exp argument is <= 0 here: never overflows.
        return 1.0 / (1.0 + math.exp(-z))
    # For z < 0 use the algebraically equivalent form exp(z) / (1 + exp(z)),
    # whose exp argument is negative and therefore safe.
    ez = math.exp(z)
    return ez / (1.0 + ez)
|
|
|
|
|
|
|
|
def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
                                       target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run one forward pass over prompt + target and slice out the logits that
    predict the target tokens.

    Args:
        llm_handler: The handler containing the model and tokenizer.
        formatted_prompt: The input context.
        target_text: The text we want to calculate probability/recall for.

    Returns:
        Tuple of (target_logits, target_ids)
        - target_logits: Logits used to predict the target tokens,
          shape (target_len, vocab).
        - target_ids: The ground truth token IDs of the target,
          shape (target_len,). Both empty when the target contributes
          no tokens beyond the prompt.
    """
    model = llm_handler.get_hf_model_for_scoring()
    tokenizer = llm_handler.llm_tokenizer
    # "pt" backend keeps the device on the handler; otherwise read it off the
    # model's own parameters.
    device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device

    # Tokenize the prompt alone only to measure its token length; the actual
    # forward pass uses the concatenated text below.
    prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
    prompt_len = prompt_tokens_temp['input_ids'].shape[1]

    # NOTE(review): tokenizing prompt and prompt+target separately assumes the
    # prompt's tokenization is a strict prefix of the concatenation's; BPE
    # merges at the boundary could shift prompt_len by a token — confirm for
    # the tokenizer in use.
    full_text = formatted_prompt + target_text
    full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)

    input_ids = full_tokens['input_ids']

    # Target contributed nothing (empty, or truncation ate it): signal with
    # empty tensors; callers check target_ids.shape[0] == 0.
    if input_ids.shape[1] <= prompt_len:
        return torch.empty(0, device=device), torch.empty(0, device=device)

    with torch.no_grad():
        with llm_handler._load_model_context():
            outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
            all_logits = outputs.logits

    # Causal-LM shift: the logits at position i predict the token at i+1, so
    # positions [prompt_len-1, last-1] predict the target tokens
    # [prompt_len, last].
    target_logits = all_logits[0, prompt_len - 1:-1, :]
    target_ids = input_ids[0, prompt_len:]

    return target_logits, target_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _calculate_topk_recall(llm_handler,
                           formatted_prompt: str,
                           target_text: str,
                           topk: int = 10) -> Tuple[float, Dict[int, float]]:
    """
    Calculate top-k recall for target text given prompt.

    Checks whether the ground-truth token is within the top-k predicted
    probabilities at each step.

    Improvement over the previous version: the ground-truth token's rank is
    located once per position instead of re-slicing and re-scanning the
    top-k list for every k (O(len * topk) vs O(len * topk^2)); the computed
    values are identical.

    Args:
        llm_handler: Handler exposing the scoring model/tokenizer.
        formatted_prompt: The input context.
        target_text: The text to score.
        topk: Largest k to report recall for.

    Returns:
        (average_recall, recall_per_k) where average_recall is the mean
        rank-weighted position score (1.0 for rank 1, decreasing linearly,
        0.0 for a miss) and recall_per_k maps k in [1, topk] to the fraction
        of positions whose ground-truth token appears in the top k.
    """
    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)

    if target_ids.shape[0] == 0:
        return 0.0, {}

    target_len = target_ids.shape[0]

    # k is capped at the vocab size; recall for larger k equals the cap's.
    _, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)

    target_ids_list = target_ids.tolist()
    topk_indices_list = topk_indices.tolist()

    # 1-based rank of the ground-truth token within the top-k candidates at
    # each position, or None when it is not among them.
    ranks = []
    for gt_token, candidates in zip(target_ids_list, topk_indices_list):
        ranks.append(candidates.index(gt_token) + 1 if gt_token in candidates else None)

    recall_per_k = {
        k: sum(1 for r in ranks if r is not None and r <= k) / target_len
        for k in range(1, topk + 1)
    }

    # Rank-weighted score per position: rank 1 -> 1.0, linearly down; 0.0
    # for positions where the token missed the top-k entirely.
    position_scores = [0.0 if r is None else 1.0 - (r - 1) / topk for r in ranks]

    average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0

    return average_recall, recall_per_k
|
|
|
|
|
|
|
|
def _calculate_metadata_recall(llm_handler,
                               formatted_prompt: str,
                               fields_dict: Dict[str, Any],
                               topk: int = 10) -> Dict[str, float]:
    """
    Score each metadata field by top-k recall of its serialized form.

    Each field is dumped to YAML, wrapped in a ``<think>`` block, and scored
    with `_calculate_topk_recall` against the prompt.

    Args:
        llm_handler: Handler exposing the scoring model/tokenizer.
        formatted_prompt: The input context.
        fields_dict: Dictionary of {field_name: field_value}.
        topk: Top-k cutoff forwarded to the recall computation.

    Returns:
        {field_name: average_recall}, covering every field; empty dict for
        empty input.
    """
    if not fields_dict:
        return {}

    results: Dict[str, float] = {}
    # Sorted order keeps scoring (and the debug log) deterministic.
    for name in sorted(fields_dict):
        serialized = yaml.dump({name: fields_dict[name]}, allow_unicode=True, sort_keys=True).strip()
        target = f"<think>\n{serialized}\n</think>\n"

        score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, target, topk=topk)

        results[name] = score
        logger.debug(f"Recall for {name}: {score:.4f}")

    return results
|
|
|
|
|
|
|
|
def _calculate_log_prob(
    llm_handler,
    formatted_prompt: str,
    target_text: str,
    temperature: float = 1.0
) -> float:
    """
    Calculate the average log probability of target text given prompt.

    Fix: ``temperature`` was previously accepted but silently ignored. It
    now scales the logits before the softmax, as the signature implies; the
    default of 1.0 reproduces the old behavior exactly.

    Args:
        llm_handler: Handler exposing the scoring model/tokenizer.
        formatted_prompt: The input context.
        target_text: The text to score.
        temperature: Softmax temperature (> 0). Values < 1 sharpen the
            distribution, values > 1 flatten it.

    Returns:
        Mean per-token log probability of the target, or -inf when the
        target contributes no tokens.

    Raises:
        ValueError: If temperature is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)

    if target_ids.shape[0] == 0:
        return float('-inf')

    # Apply temperature scaling before normalization (no-op at the default).
    if temperature != 1.0:
        pred_logits = pred_logits / temperature

    log_probs = F.log_softmax(pred_logits, dim=-1)

    # Gather the log-prob of the ground-truth token at each position.
    target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]

    mean_log_prob = target_log_probs.mean().item()

    return mean_log_prob
|
|
|
|
|
|
|
|
def calculate_reward_score(
    scores: Dict[str, float],
    weights_config: Optional[Dict[str, float]] = None
) -> Tuple[float, str]:
    """
    Reward Model Calculator: combine per-condition scores into one reward.

    Priority logic:
    1. Caption (highest): the overall vibe/style must match.
    2. Lyrics (medium): content accuracy matters but is secondary to vibe.
    3. Metadata (lowest): technical constraints (BPM, key) tolerate slight
       deviations.

    Strategy — dynamic weighted sum: every score that is not 'caption' or
    'lyrics' is averaged into a single 'metadata' component first, then the
    weights of whichever components are present are renormalized to sum to
    one before the weighted sum is taken.

    Args:
        scores: Raw per-condition scores (0.0 - 1.0).
        weights_config: Optional custom weights; defaults to
            caption=0.50, lyrics=0.30, metadata=0.20.

    Returns:
        Tuple of (final reward in 0.0 - 1.0, multi-line breakdown string
        explaining each component's contribution).
    """
    if weights_config is None:
        weights_config = {'caption': 0.50, 'lyrics': 0.30, 'metadata': 0.20}

    # Aggregate every non-caption/non-lyrics entry into one metadata score.
    meta_values = [v for k, v in scores.items() if k not in ['caption', 'lyrics']]

    # Collect the components that are actually present, in priority order
    # (this order is what the stable sort below falls back to on ties).
    components = []
    if scores.get('caption') is not None:
        components.append(('caption', scores['caption'], weights_config['caption']))
    if scores.get('lyrics') is not None:
        components.append(('lyrics', scores['lyrics'], weights_config['lyrics']))
    if meta_values:
        components.append(('metadata', sum(meta_values) / len(meta_values), weights_config['metadata']))

    weight_sum = sum(w for _, _, w in components)
    if weight_sum == 0:
        return 0.0, "❌ No valid scores available to calculate reward."

    reward = 0.0
    breakdown = []

    # Report the heaviest-weighted components first.
    for name, score, base_weight in sorted(components, key=lambda c: c[2], reverse=True):
        weight = base_weight / weight_sum
        contribution = score * weight
        reward += contribution
        breakdown.append(
            f" • {name.title():<8} | Score: {score:.4f} | Weight: {weight:.2f} "
            f"-> Contrib: +{contribution:.4f}"
        )

    return reward, "\n".join(breakdown)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_pmi_score_per_condition(
    llm_handler,
    audio_codes: str,
    caption: str = "",
    lyrics: str = "",
    metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 1.0,
    topk: int = 10,
    score_scale: float = 0.1,
) -> Tuple[Dict[str, float], float, str]:
    """
    Calculate quality score separately for each condition.

    - Structured metadata fields: Top-k recall.
    - Caption/lyrics: PMI, normalized to [0, 1].

    Fixes over the previous version:
    - No longer crashes with TypeError when ``metadata`` is None (its
      default value); it is normalized to a dict first.
    - Works on a copy so the caller's metadata dict is never mutated.
    - ``temperature`` is now forwarded to the log-prob computation.
    - Removed a dead ``global_score`` assignment that was immediately
      overwritten, and a redundant rebuild of the unconditional prompt.

    Args:
        llm_handler: Handler exposing the model/tokenizer for scoring.
        audio_codes: Serialized audio codes to condition on.
        caption: Caption text (used when metadata lacks a 'caption' key).
        lyrics: Lyrics text; scored via PMI when non-empty.
        metadata: Optional structured fields (bpm, genres, ...).
        temperature: Softmax temperature for log-prob scoring.
        topk: Top-k cutoff for metadata recall.
        score_scale: Sigmoid scale for PMI normalization.

    Returns:
        (per_condition_scores, global_reward, human_readable_status)
    """
    if not llm_handler.llm_initialized:
        return {}, 0.0, "❌ LLM not initialized"

    if not audio_codes or not audio_codes.strip():
        return {}, 0.0, "❌ No audio codes provided"

    # Copy so the caller's dict is untouched; also handles metadata=None.
    metadata = dict(metadata) if isinstance(metadata, dict) else {}
    if "caption" not in metadata:
        metadata['caption'] = caption

    formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
    # Unconditional prompt (no audio) serves as the PMI baseline; built once
    # and reused for both caption and lyrics scoring.
    prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)

    metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
    metadata_pmi_keys = ['caption']

    try:
        scores: Dict[str, float] = {}

        # 1) Structured fields -> top-k recall, one field at a time.
        for key in metadata_recall_keys:
            if key in metadata and metadata[key] is not None:
                field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, {key: metadata[key]}, topk=topk)
                scores.update(field_scores)

        # 2) Free-text fields (caption) -> normalized PMI.
        for key in metadata_pmi_keys:
            if key in metadata and metadata[key] is not None:
                cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
                target_text = f"<think>\n{cot_yaml}\n</think>\n"

                log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text, temperature=temperature)
                log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text, temperature=temperature)

                scores[key] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)

        # 3) Lyrics -> normalized PMI.
        if lyrics:
            target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"

            log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text, temperature=temperature)
            log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text, temperature=temperature)

            scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)

        if not scores:
            return {}, 0.0, "❌ No conditions to evaluate"

        global_score, breakdown_lines = calculate_reward_score(scores)

        status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
        for key, score in sorted(scores.items()):
            metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
            status_lines.append(f" {key}: {score:.4f} ({metric})")
        status = "\n".join(status_lines)

        logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
        return scores, global_score, status

    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return {}, float('-inf'), error_msg
|
|
|