|
|
""" |
|
|
5Hz LM (Language Model) Handler |
|
|
Handles all LM-related operations including initialization and generation |
|
|
""" |
|
|
import os |
|
|
import traceback |
|
|
import time |
|
|
import random |
|
|
from typing import Optional, Dict, Any, Tuple, List, Union |
|
|
from contextlib import contextmanager |
|
|
|
|
|
import yaml |
|
|
import torch |
|
|
from loguru import logger |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from transformers.generation.streamers import BaseStreamer |
|
|
from transformers.generation.logits_process import ( |
|
|
LogitsProcessorList, |
|
|
RepetitionPenaltyLogitsProcessor, |
|
|
) |
|
|
from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor |
|
|
from acestep.constants import DEFAULT_LM_INSTRUCTION, DEFAULT_LM_UNDERSTAND_INSTRUCTION, DEFAULT_LM_INSPIRED_INSTRUCTION, DEFAULT_LM_REWRITE_INSTRUCTION |
|
|
|
|
|
|
|
|
class LLMHandler: |
|
|
"""5Hz LM Handler for audio code generation""" |
|
|
|
|
|
STOP_REASONING_TAG = "</think>" |
|
|
|
|
|
def __init__(self): |
|
|
"""Initialize LLMHandler with default values""" |
|
|
self.llm = None |
|
|
self.llm_tokenizer = None |
|
|
self.llm_initialized = False |
|
|
self.llm_backend = None |
|
|
self.max_model_len = 4096 |
|
|
self.device = "cpu" |
|
|
self.dtype = torch.float32 |
|
|
self.offload_to_cpu = False |
|
|
|
|
|
|
|
|
self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None |
|
|
|
|
|
|
|
|
self._hf_model_for_scoring = None |
|
|
|
|
|
def get_available_5hz_lm_models(self) -> List[str]: |
|
|
"""Scan and return all model directory names starting with 'acestep-5Hz-lm-'""" |
|
|
current_file = os.path.abspath(__file__) |
|
|
project_root = os.path.dirname(os.path.dirname(current_file)) |
|
|
checkpoint_dir = os.path.join(project_root, "checkpoints") |
|
|
|
|
|
models = [] |
|
|
if os.path.exists(checkpoint_dir): |
|
|
for item in os.listdir(checkpoint_dir): |
|
|
item_path = os.path.join(checkpoint_dir, item) |
|
|
if os.path.isdir(item_path) and item.startswith("acestep-5Hz-lm-"): |
|
|
models.append(item) |
|
|
|
|
|
models.sort() |
|
|
return models |
|
|
|
|
|
def get_gpu_memory_utilization(self, minimal_gpu: float = 8, min_ratio: float = 0.2, max_ratio: float = 0.9) -> Tuple[float, bool]: |
|
|
"""Get GPU memory utilization ratio""" |
|
|
try: |
|
|
device = torch.device("cuda:0") |
|
|
total_gpu_mem_bytes = torch.cuda.get_device_properties(device).total_memory |
|
|
allocated_mem_bytes = torch.cuda.memory_allocated(device) |
|
|
reserved_mem_bytes = torch.cuda.memory_reserved(device) |
|
|
|
|
|
total_gpu = total_gpu_mem_bytes / 1024**3 |
|
|
low_gpu_memory_mode = False |
|
|
if total_gpu < minimal_gpu: |
|
|
minimal_gpu = 0.5 * total_gpu |
|
|
low_gpu_memory_mode = True |
|
|
allocated_gpu = allocated_mem_bytes / 1024**3 |
|
|
reserved_gpu = reserved_mem_bytes / 1024**3 |
|
|
available_gpu = total_gpu - reserved_gpu |
|
|
|
|
|
if available_gpu >= minimal_gpu: |
|
|
ratio = min(max_ratio, max(min_ratio, minimal_gpu / total_gpu)) |
|
|
else: |
|
|
ratio = min(max_ratio, max(min_ratio, (available_gpu * 0.8) / total_gpu)) |
|
|
|
|
|
return ratio, low_gpu_memory_mode |
|
|
except Exception as e: |
|
|
return 0.9, False |
|
|
|
|
|
def _has_meaningful_negative_prompt(self, negative_prompt: str) -> bool: |
|
|
"""Check if negative prompt is meaningful (not default/empty)""" |
|
|
return negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT" |
|
|
|
|
|
def _build_logits_processor(self, repetition_penalty: float) -> LogitsProcessorList: |
|
|
"""Build logits processor list with repetition penalty if needed""" |
|
|
logits_processor = LogitsProcessorList() |
|
|
if repetition_penalty != 1.0: |
|
|
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) |
|
|
return logits_processor |
|
|
|
|
|
def _setup_constrained_processor( |
|
|
self, |
|
|
use_constrained_decoding: bool, |
|
|
constrained_decoding_debug: bool, |
|
|
target_duration: Optional[float], |
|
|
user_metadata: Optional[Dict[str, Optional[str]]], |
|
|
stop_at_reasoning: bool, |
|
|
skip_genres: bool, |
|
|
skip_caption: bool, |
|
|
skip_language: bool, |
|
|
generation_phase: str, |
|
|
is_batch: bool = False, |
|
|
metadata_temperature: Optional[float] = None, |
|
|
codes_temperature: Optional[float] = None, |
|
|
) -> Optional[MetadataConstrainedLogitsProcessor]: |
|
|
"""Setup and configure constrained processor for generation""" |
|
|
use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None) |
|
|
|
|
|
if not use_constrained_decoding and not use_phase_temperatures: |
|
|
return None |
|
|
|
|
|
|
|
|
self.constrained_processor.reset() |
|
|
|
|
|
|
|
|
self.constrained_processor.enabled = use_constrained_decoding |
|
|
self.constrained_processor.debug = constrained_decoding_debug |
|
|
|
|
|
|
|
|
if use_phase_temperatures: |
|
|
self.constrained_processor.metadata_temperature = metadata_temperature |
|
|
self.constrained_processor.codes_temperature = codes_temperature |
|
|
else: |
|
|
self.constrained_processor.metadata_temperature = None |
|
|
self.constrained_processor.codes_temperature = None |
|
|
|
|
|
self.constrained_processor.set_target_duration(target_duration) |
|
|
|
|
|
|
|
|
if is_batch: |
|
|
self.constrained_processor.set_user_metadata(None) |
|
|
self.constrained_processor.set_stop_at_reasoning(False) |
|
|
self.constrained_processor.set_skip_genres(True) |
|
|
self.constrained_processor.set_skip_caption(True) |
|
|
self.constrained_processor.set_skip_language(True) |
|
|
else: |
|
|
|
|
|
self.constrained_processor.set_user_metadata(user_metadata) |
|
|
self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning) |
|
|
self.constrained_processor.set_skip_genres(skip_genres) |
|
|
self.constrained_processor.set_skip_caption(skip_caption) |
|
|
self.constrained_processor.set_skip_language(skip_language) |
|
|
|
|
|
|
|
|
self.constrained_processor.set_generation_phase(generation_phase) |
|
|
|
|
|
return self.constrained_processor |
|
|
|
|
|
def _build_unconditional_prompt( |
|
|
self, |
|
|
caption: str, |
|
|
lyrics: str, |
|
|
cot_text: str, |
|
|
negative_prompt: str, |
|
|
generation_phase: str, |
|
|
is_batch: bool = False, |
|
|
) -> str: |
|
|
"""Build unconditional prompt for CFG based on generation phase and batch mode""" |
|
|
if is_batch or generation_phase == "codes": |
|
|
|
|
|
return self.build_formatted_prompt_with_cot( |
|
|
caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt |
|
|
) |
|
|
else: |
|
|
|
|
|
|
|
|
return self.build_formatted_prompt( |
|
|
caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt |
|
|
) |
|
|
|
|
|
def _load_pytorch_model(self, model_path: str, device: str) -> Tuple[bool, str]: |
|
|
"""Load PyTorch model from path and return (success, status_message)""" |
|
|
try: |
|
|
self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) |
|
|
if not self.offload_to_cpu: |
|
|
self.llm = self.llm.to(device).to(self.dtype) |
|
|
else: |
|
|
self.llm = self.llm.to("cpu").to(self.dtype) |
|
|
self.llm.eval() |
|
|
self.llm_backend = "pt" |
|
|
self.llm_initialized = True |
|
|
logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}") |
|
|
status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch\nDevice: {device}" |
|
|
return True, status_msg |
|
|
except Exception as e: |
|
|
return False, f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" |
|
|
|
|
|
def _apply_top_k_filter(self, logits: torch.Tensor, top_k: Optional[int]) -> torch.Tensor: |
|
|
"""Apply top-k filtering to logits""" |
|
|
if top_k is not None and top_k > 0: |
|
|
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] |
|
|
logits[indices_to_remove] = float('-inf') |
|
|
return logits |
|
|
|
|
|
def _apply_top_p_filter(self, logits: torch.Tensor, top_p: Optional[float]) -> torch.Tensor: |
|
|
"""Apply top-p (nucleus) filtering to logits""" |
|
|
if top_p is not None and 0.0 < top_p < 1.0: |
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
|
|
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) |
|
|
sorted_indices_to_remove = cumulative_probs > top_p |
|
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
|
|
sorted_indices_to_remove[..., 0] = 0 |
|
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) |
|
|
logits[indices_to_remove] = float('-inf') |
|
|
return logits |
|
|
|
|
|
def _sample_tokens(self, logits: torch.Tensor, temperature: float) -> torch.Tensor: |
|
|
"""Sample tokens from logits with temperature""" |
|
|
if temperature > 0: |
|
|
logits = logits / temperature |
|
|
probs = torch.softmax(logits, dim=-1) |
|
|
return torch.multinomial(probs, num_samples=1).squeeze(1) |
|
|
else: |
|
|
return torch.argmax(logits, dim=-1) |
|
|
|
|
|
def _check_eos_token(self, tokens: torch.Tensor, eos_token_id: int, pad_token_id: Optional[int]) -> bool: |
|
|
"""Check if any token in the batch is EOS or pad token""" |
|
|
if torch.any(tokens == eos_token_id): |
|
|
return True |
|
|
if pad_token_id is not None and pad_token_id != eos_token_id: |
|
|
if torch.any(tokens == pad_token_id): |
|
|
return True |
|
|
return False |
|
|
|
|
|
def _update_constrained_processor_state(self, constrained_processor: Optional[MetadataConstrainedLogitsProcessor], tokens: torch.Tensor): |
|
|
"""Update constrained processor state with generated tokens""" |
|
|
if constrained_processor is not None: |
|
|
for b in range(tokens.shape[0]): |
|
|
constrained_processor.update_state(tokens[b].item()) |
|
|
|
|
|
def _forward_pass( |
|
|
self, |
|
|
model: Any, |
|
|
generated_ids: torch.Tensor, |
|
|
model_kwargs: Dict[str, Any], |
|
|
past_key_values: Optional[Any], |
|
|
use_cache: bool, |
|
|
) -> Any: |
|
|
"""Perform forward pass with KV cache support""" |
|
|
if past_key_values is None: |
|
|
outputs = model( |
|
|
input_ids=generated_ids, |
|
|
**model_kwargs, |
|
|
use_cache=use_cache, |
|
|
) |
|
|
else: |
|
|
outputs = model( |
|
|
input_ids=generated_ids[:, -1:], |
|
|
past_key_values=past_key_values, |
|
|
**model_kwargs, |
|
|
use_cache=use_cache, |
|
|
) |
|
|
return outputs |
|
|
|
|
|
def _normalize_batch_input(self, formatted_prompts: Union[str, List[str]]) -> Tuple[List[str], bool]: |
|
|
"""Normalize batch input: convert single string to list and return (list, is_batch)""" |
|
|
is_batch = isinstance(formatted_prompts, list) |
|
|
if is_batch: |
|
|
return formatted_prompts, is_batch |
|
|
else: |
|
|
return [formatted_prompts], is_batch |
|
|
|
|
|
    def initialize(
        self,
        checkpoint_dir: str,
        lm_model_path: str,
        backend: str = "vllm",
        device: str = "auto",
        offload_to_cpu: bool = False,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[str, bool]:
        """
        Initialize 5Hz LM model

        Loads the tokenizer, builds the constrained-decoding processor, then
        loads the model with the requested backend. A failed vllm init falls
        back to the PyTorch backend automatically.

        Args:
            checkpoint_dir: Checkpoint directory path
            lm_model_path: LM model path (relative to checkpoint_dir)
            backend: Backend type ("vllm" or "pt")
            device: Device type ("auto", "cuda", or "cpu")
            offload_to_cpu: Whether to offload to CPU
            dtype: Data type (if None, auto-detect based on device)

        Returns:
            (status_message, success)
        """
        try:
            if device == "auto":
                device = "cuda" if torch.cuda.is_available() else "cpu"

            self.device = device
            self.offload_to_cpu = offload_to_cpu

            if dtype is None:
                # bf16 on accelerators, fp32 on CPU.
                self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32
            else:
                self.dtype = dtype

            full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
            if not os.path.exists(full_lm_model_path):
                return f"❌ 5Hz LM model not found at {full_lm_model_path}", False

            logger.info("loading 5Hz LM tokenizer... it may take 80~90s")
            start_time = time.time()

            llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
            logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
            self.llm_tokenizer = llm_tokenizer

            # The processor is shared across generations and reconfigured per
            # run by _setup_constrained_processor.
            logger.info("Initializing constrained decoding processor...")
            processor_start = time.time()
            self.constrained_processor = MetadataConstrainedLogitsProcessor(
                tokenizer=self.llm_tokenizer,
                enabled=True,
                debug=False,
            )
            logger.info(f"Constrained processor initialized in {time.time() - processor_start:.2f} seconds")

            if backend == "vllm":
                status_msg = self._initialize_5hz_lm_vllm(full_lm_model_path)
                logger.info(f"5Hz LM status message: {status_msg}")

                # A "❌" status with llm_initialized still False means vllm
                # could not start; retry with the PyTorch backend.
                if status_msg.startswith("❌"):
                    if not self.llm_initialized:
                        logger.warning("vllm initialization failed, falling back to PyTorch backend")
                        success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
                        if not success:
                            return status_msg, False
                        status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
            else:
                # Explicit PyTorch backend.
                success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
                if not success:
                    return status_msg, False

            return status_msg, True

        except Exception as e:
            return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
|
|
|
|
|
    def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
        """Initialize 5Hz LM model using vllm backend

        Requires CUDA and the nano-vllm package; returns a human-readable
        status string prefixed with "✅" on success or "❌" on failure
        (the caller inspects that prefix to decide on a PyTorch fallback).
        """
        if not torch.cuda.is_available():
            self.llm_initialized = False
            logger.error("CUDA is not available. Please check your GPU setup.")
            return "❌ CUDA is not available. Please check your GPU setup."
        try:
            from nanovllm import LLM, SamplingParams
        except ImportError:
            self.llm_initialized = False
            logger.error("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .")
            return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install ."

        try:
            current_device = torch.cuda.current_device()
            device_name = torch.cuda.get_device_name(current_device)

            # Free cached blocks so the utilization estimate sees real headroom.
            torch.cuda.empty_cache()
            gpu_memory_utilization, low_gpu_memory_mode = self.get_gpu_memory_utilization(
                minimal_gpu=8,
                min_ratio=0.2,
                max_ratio=0.9
            )
            # Halve the context window on small cards.
            if low_gpu_memory_mode:
                self.max_model_len = 2048
            else:
                self.max_model_len = 4096

            logger.info(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
            start_time = time.time()
            self.llm = LLM(
                model=model_path,
                enforce_eager=False,
                tensor_parallel_size=1,
                max_model_len=self.max_model_len,
                gpu_memory_utilization=gpu_memory_utilization,
                tokenizer=self.llm_tokenizer,
            )
            logger.info(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
            self.llm_initialized = True
            self.llm_backend = "vllm"
            return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
        except Exception as e:
            self.llm_initialized = False
            return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
|
|
|
|
    def _run_vllm(
        self,
        formatted_prompts: Union[str, List[str]],
        temperature: float,
        cfg_scale: float,
        negative_prompt: str,
        top_k: Optional[int],
        top_p: Optional[float],
        repetition_penalty: float,
        use_constrained_decoding: bool = True,
        constrained_decoding_debug: bool = False,
        metadata_temperature: Optional[float] = None,
        codes_temperature: Optional[float] = None,
        target_duration: Optional[float] = None,
        user_metadata: Optional[Dict[str, Optional[str]]] = None,
        stop_at_reasoning: bool = False,
        skip_genres: bool = True,
        skip_caption: bool = False,
        skip_language: bool = False,
        generation_phase: str = "cot",
        caption: str = "",
        lyrics: str = "",
        cot_text: str = "",
        seeds: Optional[List[int]] = None,
    ) -> Union[str, List[str]]:
        """
        Unified vllm generation function supporting both single and batch modes.

        Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
        Returns a single string for single mode, or a list of strings for batch mode.

        NOTE(review): `seeds` is accepted for signature parity with _run_pt
        but is not used on this path — confirm whether nano-vllm seeding is
        intended here.
        """
        from nanovllm import SamplingParams

        formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
        batch_size = len(formatted_prompt_list)

        # Phase temperatures (separate temps for metadata vs codes tokens)
        # only apply in single mode; when active, the engine-level sampler
        # temperature is neutralized to 1.0 and the constrained processor
        # applies the per-phase scaling instead.
        use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
        effective_sampler_temp = 1.0 if use_phase_temperatures else temperature

        # The processor is also required when only phase temperatures are
        # requested, hence the "or" below.
        constrained_processor = self._setup_constrained_processor(
            use_constrained_decoding=use_constrained_decoding or use_phase_temperatures,
            constrained_decoding_debug=constrained_decoding_debug,
            target_duration=target_duration,
            user_metadata=user_metadata,
            stop_at_reasoning=stop_at_reasoning,
            skip_genres=skip_genres,
            skip_caption=skip_caption,
            skip_language=skip_language,
            generation_phase=generation_phase,
            is_batch=is_batch,
            metadata_temperature=metadata_temperature,
            codes_temperature=codes_temperature,
        )

        # Reserve 64 tokens of headroom below the model's context window.
        sampling_params = SamplingParams(
            max_tokens=self.max_model_len - 64,
            temperature=effective_sampler_temp,
            cfg_scale=cfg_scale,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            logits_processor=constrained_processor,
            logits_processor_update_state=constrained_processor.update_state if constrained_processor else None,
        )

        if cfg_scale > 1.0:
            # Classifier-free guidance: every item shares one unconditional
            # (negative) prompt.
            formatted_unconditional_prompt = self._build_unconditional_prompt(
                caption=caption,
                lyrics=lyrics,
                cot_text=cot_text,
                negative_prompt=negative_prompt,
                generation_phase=generation_phase,
                is_batch=is_batch,
            )
            unconditional_prompts = [formatted_unconditional_prompt] * batch_size

            outputs = self.llm.generate(
                formatted_prompt_list,
                sampling_params,
                unconditional_prompts=unconditional_prompts,
            )
        else:
            outputs = self.llm.generate(formatted_prompt_list, sampling_params)

        # Normalize across the output shapes different engine versions return:
        # RequestOutput-like objects, plain .text holders, dicts, or strings.
        output_texts = []
        for output in outputs:
            if hasattr(output, "outputs") and len(output.outputs) > 0:
                output_texts.append(output.outputs[0].text)
            elif hasattr(output, "text"):
                output_texts.append(output.text)
            elif isinstance(output, dict) and "text" in output:
                output_texts.append(output["text"])
            else:
                output_texts.append(str(output))

        # Single-prompt callers get a bare string back.
        return output_texts[0] if not is_batch else output_texts
|
|
|
|
|
    def _run_pt_single(
        self,
        formatted_prompt: str,
        temperature: float,
        cfg_scale: float,
        negative_prompt: str,
        top_k: Optional[int],
        top_p: Optional[float],
        repetition_penalty: float,
        use_constrained_decoding: bool,
        constrained_decoding_debug: bool,
        target_duration: Optional[float],
        user_metadata: Optional[Dict[str, Optional[str]]],
        stop_at_reasoning: bool,
        skip_genres: bool,
        skip_caption: bool,
        skip_language: bool,
        generation_phase: str,
        caption: str,
        lyrics: str,
        cot_text: str,
    ) -> str:
        """Internal helper function for single-item PyTorch generation.

        Picks one of three generation paths:
        1. cfg_scale > 1.0: custom CFG loop over a [conditional, unconditional]
           pair of prompts (left-padded so both rows end at the same position);
        2. constrained decoding without CFG: custom constrained loop;
        3. plain HuggingFace generate().
        Returns only the newly generated text (prompt stripped), with special
        tokens preserved.
        """
        inputs = self.llm_tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=False,
            truncation=True,
        )

        constrained_processor = self._setup_constrained_processor(
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
            target_duration=target_duration,
            user_metadata=user_metadata,
            stop_at_reasoning=stop_at_reasoning,
            skip_genres=skip_genres,
            skip_caption=skip_caption,
            skip_language=skip_language,
            generation_phase=generation_phase,
            is_batch=False,
        )

        # _load_model_context handles CPU-offload: the model is on the target
        # device only inside this block.
        with self._load_model_context():
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            max_new_tokens = getattr(self.llm.config, "max_new_tokens", 4096)
            # Keep 64 tokens of headroom below the context window.
            if hasattr(self, "max_model_len"):
                max_new_tokens = min(max_new_tokens, self.max_model_len - 64)

            logits_processor = self._build_logits_processor(repetition_penalty)

            if cfg_scale > 1.0:
                formatted_unconditional_prompt = self._build_unconditional_prompt(
                    caption=caption,
                    lyrics=lyrics,
                    cot_text=cot_text,
                    negative_prompt=negative_prompt,
                    generation_phase=generation_phase,
                    is_batch=False,
                )

                # Tokenize conditional + unconditional together with LEFT
                # padding so generation appends at the same position for both
                # rows; restore the tokenizer's padding side afterwards.
                batch_texts = [formatted_prompt, formatted_unconditional_prompt]
                original_padding_side = self.llm_tokenizer.padding_side
                self.llm_tokenizer.padding_side = 'left'
                batch_inputs_tokenized = self.llm_tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                )
                self.llm_tokenizer.padding_side = original_padding_side
                batch_inputs_tokenized = {k: v.to(self.device) for k, v in batch_inputs_tokenized.items()}

                batch_input_ids = batch_inputs_tokenized['input_ids']
                batch_attention_mask = batch_inputs_tokenized.get('attention_mask', None)

                outputs = self._generate_with_cfg_custom(
                    batch_input_ids=batch_input_ids,
                    batch_attention_mask=batch_attention_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
                    streamer=None,
                    constrained_processor=constrained_processor,
                )

                # Row 0 is the conditional sequence; drop the unconditional row.
                outputs = outputs[0:1]
            elif use_constrained_decoding:
                outputs = self._generate_with_constrained_decoding(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
                    streamer=None,
                    constrained_processor=constrained_processor,
                )
            else:
                # Vanilla HF generate; temperature 0 maps to greedy decoding.
                with torch.no_grad():
                    outputs = self.llm.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature if temperature > 0 else 1.0,
                        do_sample=True if temperature > 0 else False,
                        top_k=top_k if top_k is not None and top_k > 0 else None,
                        top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
                        logits_processor=logits_processor if len(logits_processor) > 0 else None,
                        pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
                        streamer=None,
                    )

        # Normalize to a 1-D tensor of token ids for the first sequence.
        if isinstance(outputs, torch.Tensor):
            if outputs.dim() == 2:
                generated_ids = outputs[0]
            else:
                generated_ids = outputs
        else:
            generated_ids = outputs[0]

        # Strip the prompt: CFG mode used the padded batch length, the other
        # paths used the unpadded single-prompt length.
        if cfg_scale > 1.0:
            input_length = batch_inputs_tokenized['input_ids'].shape[1]
        else:
            input_length = inputs["input_ids"].shape[1]

        generated_ids = generated_ids[input_length:]

        if generated_ids.is_cuda:
            generated_ids = generated_ids.cpu()

        # Keep special tokens: downstream parsing relies on markers like </think>.
        output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
        return output_text
|
|
|
|
|
def _run_pt( |
|
|
self, |
|
|
formatted_prompts: Union[str, List[str]], |
|
|
temperature: float, |
|
|
cfg_scale: float, |
|
|
negative_prompt: str, |
|
|
top_k: Optional[int], |
|
|
top_p: Optional[float], |
|
|
repetition_penalty: float, |
|
|
use_constrained_decoding: bool = True, |
|
|
constrained_decoding_debug: bool = False, |
|
|
target_duration: Optional[float] = None, |
|
|
user_metadata: Optional[Dict[str, Optional[str]]] = None, |
|
|
stop_at_reasoning: bool = False, |
|
|
skip_genres: bool = True, |
|
|
skip_caption: bool = False, |
|
|
skip_language: bool = False, |
|
|
generation_phase: str = "cot", |
|
|
caption: str = "", |
|
|
lyrics: str = "", |
|
|
cot_text: str = "", |
|
|
seeds: Optional[List[int]] = None, |
|
|
) -> Union[str, List[str]]: |
|
|
""" |
|
|
Unified PyTorch generation function supporting both single and batch modes. |
|
|
Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]). |
|
|
Returns a single string for single mode, or a list of strings for batch mode. |
|
|
Note: PyTorch backend processes batch items sequentially (doesn't support true batching efficiently). |
|
|
""" |
|
|
|
|
|
formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts) |
|
|
|
|
|
|
|
|
if is_batch: |
|
|
output_texts = [] |
|
|
for i, formatted_prompt in enumerate(formatted_prompt_list): |
|
|
|
|
|
if seeds and i < len(seeds): |
|
|
torch.manual_seed(seeds[i]) |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.manual_seed_all(seeds[i]) |
|
|
|
|
|
|
|
|
output_text = self._run_pt_single( |
|
|
formatted_prompt=formatted_prompt, |
|
|
temperature=temperature, |
|
|
cfg_scale=cfg_scale, |
|
|
negative_prompt=negative_prompt, |
|
|
top_k=top_k, |
|
|
top_p=top_p, |
|
|
repetition_penalty=repetition_penalty, |
|
|
use_constrained_decoding=use_constrained_decoding, |
|
|
constrained_decoding_debug=constrained_decoding_debug, |
|
|
target_duration=target_duration, |
|
|
user_metadata=None, |
|
|
stop_at_reasoning=False, |
|
|
skip_genres=True, |
|
|
skip_caption=True, |
|
|
skip_language=True, |
|
|
generation_phase=generation_phase, |
|
|
caption=caption, |
|
|
lyrics=lyrics, |
|
|
cot_text=cot_text, |
|
|
) |
|
|
|
|
|
output_texts.append(output_text) |
|
|
|
|
|
return output_texts |
|
|
|
|
|
|
|
|
formatted_prompt = formatted_prompt_list[0] |
|
|
|
|
|
return self._run_pt_single( |
|
|
formatted_prompt=formatted_prompt, |
|
|
temperature=temperature, |
|
|
cfg_scale=cfg_scale, |
|
|
negative_prompt=negative_prompt, |
|
|
top_k=top_k, |
|
|
top_p=top_p, |
|
|
repetition_penalty=repetition_penalty, |
|
|
use_constrained_decoding=use_constrained_decoding, |
|
|
constrained_decoding_debug=constrained_decoding_debug, |
|
|
target_duration=target_duration, |
|
|
user_metadata=user_metadata, |
|
|
stop_at_reasoning=stop_at_reasoning, |
|
|
skip_genres=skip_genres, |
|
|
skip_caption=skip_caption, |
|
|
skip_language=skip_language, |
|
|
generation_phase=generation_phase, |
|
|
caption=caption, |
|
|
lyrics=lyrics, |
|
|
cot_text=cot_text, |
|
|
) |
|
|
|
|
|
def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool: |
|
|
"""Check if all required metadata are present.""" |
|
|
if user_metadata is None: |
|
|
return False |
|
|
if 'bpm' in user_metadata and 'keyscale' in user_metadata and 'timesignature' in user_metadata and 'duration' in user_metadata: |
|
|
return True |
|
|
return False |
|
|
|
|
|
def _format_metadata_as_cot(self, metadata: Dict[str, Any]) -> str: |
|
|
""" |
|
|
Format parsed metadata as CoT text using YAML format (matching training format). |
|
|
|
|
|
Args: |
|
|
metadata: Dictionary with keys: bpm, caption, duration, keyscale, language, timesignature |
|
|
|
|
|
Returns: |
|
|
Formatted CoT text: "<think>\n{yaml_content}\n</think>" |
|
|
""" |
|
|
|
|
|
cot_items = {} |
|
|
for key in ['bpm', 'caption', 'duration', 'keyscale', 'language', 'timesignature']: |
|
|
if key in metadata and metadata[key] is not None: |
|
|
value = metadata[key] |
|
|
if key == "timesignature" and value.endswith("/4"): |
|
|
value = value.split("/")[0] |
|
|
if isinstance(value, str) and value.isdigit(): |
|
|
value = int(value) |
|
|
cot_items[key] = value |
|
|
|
|
|
|
|
|
if len(cot_items) > 0: |
|
|
cot_yaml = yaml.dump(cot_items, allow_unicode=True, sort_keys=True).strip() |
|
|
else: |
|
|
cot_yaml = "" |
|
|
|
|
|
return f"<think>\n{cot_yaml}\n</think>" |
|
|
|
|
|
def generate_with_stop_condition( |
|
|
self, |
|
|
caption: str, |
|
|
lyrics: str, |
|
|
infer_type: str, |
|
|
temperature: float = 0.85, |
|
|
cfg_scale: float = 1.0, |
|
|
negative_prompt: str = "NO USER INPUT", |
|
|
top_k: Optional[int] = None, |
|
|
top_p: Optional[float] = None, |
|
|
repetition_penalty: float = 1.0, |
|
|
use_constrained_decoding: bool = True, |
|
|
constrained_decoding_debug: bool = False, |
|
|
target_duration: Optional[float] = None, |
|
|
user_metadata: Optional[Dict[str, Optional[str]]] = None, |
|
|
use_cot_metas: bool = True, |
|
|
use_cot_caption: bool = True, |
|
|
use_cot_language: bool = True, |
|
|
batch_size: Optional[int] = None, |
|
|
seeds: Optional[List[int]] = None, |
|
|
progress=None, |
|
|
) -> Dict[str, Any]: |
|
|
"""Two-phase LM generation: CoT generation followed by audio codes generation. |
|
|
|
|
|
- infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes) |
|
|
- infer_type='llm_dit': Phase 1 + Phase 2 - generate CoT then audio codes |
|
|
|
|
|
Args: |
|
|
target_duration: Target duration in seconds for codes generation constraint. |
|
|
5 codes = 1 second. If specified, blocks EOS until target reached. |
|
|
user_metadata: User-provided metadata fields (e.g. bpm/duration/keyscale/timesignature). |
|
|
If specified, constrained decoding will inject these values directly. |
|
|
use_cot_caption: Whether to generate caption in CoT (default True). |
|
|
use_cot_language: Whether to generate language in CoT (default True). |
|
|
batch_size: Optional batch size for batch generation. If None or 1, returns single result. |
|
|
If > 1, returns batch results (lists). |
|
|
seeds: Optional list of seeds for batch generation (for reproducibility). |
|
|
Only used when batch_size > 1. TODO: not used yet |
|
|
|
|
|
Returns: |
|
|
Dictionary containing: |
|
|
- metadata: Dict or List[Dict] - Generated metadata |
|
|
- audio_codes: str or List[str] - Generated audio codes |
|
|
- success: bool - Whether generation succeeded |
|
|
- error: Optional[str] - Error message if failed |
|
|
- extra_outputs: Dict with time_costs and other info |
|
|
""" |
|
|
if progress is None: |
|
|
def progress(*args, **kwargs): |
|
|
pass |
|
|
|
|
|
infer_type = (infer_type or "").strip().lower() |
|
|
if infer_type not in {"dit", "llm_dit"}: |
|
|
error_msg = f"invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')" |
|
|
return { |
|
|
"metadata": [] if (batch_size and batch_size > 1) else {}, |
|
|
"audio_codes": [] if (batch_size and batch_size > 1) else "", |
|
|
"success": False, |
|
|
"error": error_msg, |
|
|
"extra_outputs": {"time_costs": {}}, |
|
|
} |
|
|
|
|
|
|
|
|
is_batch = batch_size and batch_size > 1 |
|
|
actual_batch_size = batch_size if is_batch else 1 |
|
|
|
|
|
|
|
|
metadata = {} |
|
|
audio_codes = "" |
|
|
has_all_metas = self.has_all_metas(user_metadata) |
|
|
phase1_time = 0.0 |
|
|
phase2_time = 0.0 |
|
|
|
|
|
|
|
|
if is_batch: |
|
|
if seeds is None: |
|
|
seeds = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)] |
|
|
elif len(seeds) < actual_batch_size: |
|
|
seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size - len(seeds))] |
|
|
else: |
|
|
seeds = seeds[:actual_batch_size] |
|
|
|
|
|
|
|
|
|
|
|
progress(0.1, f"Phase 1: Generating CoT metadata (once for all items)...") |
|
|
if not has_all_metas and use_cot_metas: |
|
|
if is_batch: |
|
|
logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...") |
|
|
else: |
|
|
logger.info("Phase 1: Generating CoT metadata...") |
|
|
phase1_start = time.time() |
|
|
|
|
|
|
|
|
formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot") |
|
|
|
|
|
logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}") |
|
|
|
|
|
cot_output_text, status = self.generate_from_formatted_prompt( |
|
|
formatted_prompt=formatted_prompt, |
|
|
cfg={ |
|
|
"temperature": temperature, |
|
|
"cfg_scale": cfg_scale, |
|
|
"negative_prompt": negative_prompt, |
|
|
"top_k": top_k, |
|
|
"top_p": top_p, |
|
|
"repetition_penalty": repetition_penalty, |
|
|
"target_duration": None, |
|
|
"user_metadata": user_metadata, |
|
|
"skip_caption": not use_cot_caption, |
|
|
"skip_language": not use_cot_language, |
|
|
"skip_genres": True, |
|
|
"generation_phase": "cot", |
|
|
|
|
|
"caption": caption, |
|
|
"lyrics": lyrics, |
|
|
}, |
|
|
use_constrained_decoding=use_constrained_decoding, |
|
|
constrained_decoding_debug=constrained_decoding_debug, |
|
|
stop_at_reasoning=True, |
|
|
) |
|
|
|
|
|
phase1_time = time.time() - phase1_start |
|
|
|
|
|
if not cot_output_text: |
|
|
return { |
|
|
"metadata": [] if is_batch else {}, |
|
|
"audio_codes": [] if is_batch else "", |
|
|
"success": False, |
|
|
"error": status, |
|
|
"extra_outputs": {"time_costs": {"phase1_time": phase1_time}}, |
|
|
} |
|
|
|
|
|
|
|
|
metadata, _ = self.parse_lm_output(cot_output_text) |
|
|
if is_batch: |
|
|
logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}") |
|
|
else: |
|
|
logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}") |
|
|
else: |
|
|
|
|
|
if is_batch: |
|
|
logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)") |
|
|
else: |
|
|
logger.info("Phase 1: Using user-provided metadata (skipping generation)") |
|
|
metadata = {k: v for k, v in user_metadata.items() if v is not None} |
|
|
|
|
|
|
|
|
if infer_type == "dit": |
|
|
if is_batch: |
|
|
metadata_list = [metadata.copy() for _ in range(actual_batch_size)] |
|
|
return { |
|
|
"metadata": metadata_list, |
|
|
"audio_codes": [""] * actual_batch_size, |
|
|
"success": True, |
|
|
"error": None, |
|
|
"extra_outputs": { |
|
|
"time_costs": { |
|
|
"phase1_time": phase1_time, |
|
|
"total_time": phase1_time, |
|
|
} |
|
|
}, |
|
|
} |
|
|
else: |
|
|
return { |
|
|
"metadata": metadata, |
|
|
"audio_codes": "", |
|
|
"success": True, |
|
|
"error": None, |
|
|
"extra_outputs": { |
|
|
"time_costs": { |
|
|
"phase1_time": phase1_time, |
|
|
"total_time": phase1_time, |
|
|
} |
|
|
}, |
|
|
} |
|
|
|
|
|
|
|
|
if is_batch: |
|
|
logger.info(f"Batch Phase 2: Generating audio codes for {actual_batch_size} items...") |
|
|
else: |
|
|
logger.info("Phase 2: Generating audio codes...") |
|
|
phase2_start = time.time() |
|
|
|
|
|
|
|
|
cot_text = self._format_metadata_as_cot(metadata) |
|
|
|
|
|
|
|
|
formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text) |
|
|
logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}") |
|
|
|
|
|
progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...") |
|
|
if is_batch: |
|
|
|
|
|
formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size |
|
|
|
|
|
|
|
|
try: |
|
|
if self.llm_backend == "vllm": |
|
|
codes_outputs = self._run_vllm( |
|
|
formatted_prompts=formatted_prompts, |
|
|
temperature=temperature, |
|
|
cfg_scale=cfg_scale, |
|
|
negative_prompt=negative_prompt, |
|
|
top_k=top_k, |
|
|
top_p=top_p, |
|
|
repetition_penalty=repetition_penalty, |
|
|
use_constrained_decoding=use_constrained_decoding, |
|
|
constrained_decoding_debug=constrained_decoding_debug, |
|
|
target_duration=target_duration, |
|
|
generation_phase="codes", |
|
|
caption=caption, |
|
|
lyrics=lyrics, |
|
|
cot_text=cot_text, |
|
|
seeds=seeds, |
|
|
) |
|
|
else: |
|
|
codes_outputs = self._run_pt( |
|
|
formatted_prompts=formatted_prompts, |
|
|
temperature=temperature, |
|
|
cfg_scale=cfg_scale, |
|
|
negative_prompt=negative_prompt, |
|
|
top_k=top_k, |
|
|
top_p=top_p, |
|
|
repetition_penalty=repetition_penalty, |
|
|
use_constrained_decoding=use_constrained_decoding, |
|
|
constrained_decoding_debug=constrained_decoding_debug, |
|
|
target_duration=target_duration, |
|
|
generation_phase="codes", |
|
|
caption=caption, |
|
|
lyrics=lyrics, |
|
|
cot_text=cot_text, |
|
|
seeds=seeds, |
|
|
) |
|
|
except Exception as e: |
|
|
error_msg = f"Error in batch codes generation: {str(e)}" |
|
|
logger.error(error_msg) |
|
|
return { |
|
|
"metadata": [], |
|
|
"audio_codes": [], |
|
|
"success": False, |
|
|
"error": error_msg, |
|
|
"extra_outputs": { |
|
|
"time_costs": { |
|
|
"phase1_time": phase1_time, |
|
|
"phase2_time": 0.0, |
|
|
"total_time": phase1_time, |
|
|
} |
|
|
}, |
|
|
} |
|
|
|
|
|
|
|
|
audio_codes_list = [] |
|
|
metadata_list = [] |
|
|
for output_text in codes_outputs: |
|
|
_, audio_codes_item = self.parse_lm_output(output_text) |
|
|
audio_codes_list.append(audio_codes_item) |
|
|
metadata_list.append(metadata.copy()) |
|
|
|
|
|
phase2_time = time.time() - phase2_start |
|
|
|
|
|
|
|
|
codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list] |
|
|
logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}") |
|
|
|
|
|
total_time = phase1_time + phase2_time |
|
|
return { |
|
|
"metadata": metadata_list, |
|
|
"audio_codes": audio_codes_list, |
|
|
"success": True, |
|
|
"error": None, |
|
|
"extra_outputs": { |
|
|
"time_costs": { |
|
|
"phase1_time": phase1_time, |
|
|
"phase2_time": phase2_time, |
|
|
"total_time": total_time, |
|
|
}, |
|
|
"codes_counts": codes_counts, |
|
|
"total_codes": sum(codes_counts), |
|
|
}, |
|
|
} |
|
|
else: |
|
|
|
|
|
codes_output_text, status = self.generate_from_formatted_prompt( |
|
|
formatted_prompt=formatted_prompt_with_cot, |
|
|
cfg={ |
|
|
"temperature": temperature, |
|
|
"cfg_scale": cfg_scale, |
|
|
"negative_prompt": negative_prompt, |
|
|
"top_k": top_k, |
|
|
"top_p": top_p, |
|
|
"repetition_penalty": repetition_penalty, |
|
|
"target_duration": target_duration, |
|
|
"user_metadata": None, |
|
|
"skip_caption": True, |
|
|
"skip_language": True, |
|
|
"generation_phase": "codes", |
|
|
|
|
|
"caption": caption, |
|
|
"lyrics": lyrics, |
|
|
"cot_text": cot_text, |
|
|
}, |
|
|
use_constrained_decoding=use_constrained_decoding, |
|
|
constrained_decoding_debug=constrained_decoding_debug, |
|
|
stop_at_reasoning=False, |
|
|
) |
|
|
|
|
|
if not codes_output_text: |
|
|
total_time = phase1_time + phase2_time |
|
|
return { |
|
|
"metadata": metadata, |
|
|
"audio_codes": "", |
|
|
"success": False, |
|
|
"error": status, |
|
|
"extra_outputs": { |
|
|
"time_costs": { |
|
|
"phase1_time": phase1_time, |
|
|
"phase2_time": phase2_time, |
|
|
"total_time": total_time, |
|
|
} |
|
|
}, |
|
|
} |
|
|
|
|
|
phase2_time = time.time() - phase2_start |
|
|
|
|
|
|
|
|
_, audio_codes = self.parse_lm_output(codes_output_text) |
|
|
|
|
|
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0 |
|
|
logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes") |
|
|
|
|
|
total_time = phase1_time + phase2_time |
|
|
return { |
|
|
"metadata": metadata, |
|
|
"audio_codes": audio_codes, |
|
|
"success": True, |
|
|
"error": None, |
|
|
"extra_outputs": { |
|
|
"time_costs": { |
|
|
"phase1_time": phase1_time, |
|
|
"phase2_time": phase2_time, |
|
|
"total_time": total_time, |
|
|
}, |
|
|
"codes_count": codes_count, |
|
|
}, |
|
|
} |
|
|
|
|
|
def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
    """
    Build the chat-formatted prompt for the 5Hz LM from caption/lyrics.

    Args:
        caption: Caption text.
        lyrics: Lyrics text.
        is_negative_prompt: If True, builds the unconditional prompt used for CFG.
        generation_phase: "cot" or "codes" — changes the shape of the unconditional prompt.
        negative_prompt: Negative prompt for CFG (only consulted when is_negative_prompt=True).

    Returns:
        The tokenizer's chat-template rendering of a system + user message pair.

    Raises:
        ValueError: If the tokenizer has not been initialized yet.

    Example:
        prompt = handler.build_formatted_prompt("calm piano", "hello world")
    """
    if self.llm_tokenizer is None:
        raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")

    # Select the user-message body. Guard clauses replace the nested if/else:
    # the positive prompt always carries both caption and lyrics; the CFG
    # unconditional prompt depends on the phase and on whether a meaningful
    # negative prompt was supplied.
    if not is_negative_prompt:
        user_text = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
    elif generation_phase != "cot":
        # Codes-phase unconditional prompt is the raw caption text.
        user_text = caption
    elif self._has_meaningful_negative_prompt(negative_prompt):
        # CoT-phase unconditional prompt with an explicit negative caption.
        user_text = f"# Caption\n{negative_prompt}\n\n# Lyric\n{lyrics}\n"
    else:
        # CoT-phase unconditional prompt without any caption section.
        user_text = f"# Lyric\n{lyrics}\n"

    chat = [
        {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSTRUCTION}\n\n"},
        {"role": "user", "content": user_text},
    ]
    return self.llm_tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
|
|
|
|
|
def build_formatted_prompt_with_cot(self, caption: str, lyrics: str, cot_text: str, is_negative_prompt: bool = False, negative_prompt: str = "NO USER INPUT") -> str:
    """
    Build the codes-phase chat prompt that embeds a pre-generated CoT block.

    Args:
        caption: Caption text.
        lyrics: Lyrics text.
        cot_text: Pre-generated CoT text (e.g., "<think>\\nbpm: 120\\n...\\n</think>").
        is_negative_prompt: If True, uses an empty CoT block for the CFG unconditional prompt.
        negative_prompt: Negative prompt for CFG (only consulted when is_negative_prompt=True).

    Returns:
        Formatted prompt string, guaranteed to end with a newline.

    Raises:
        ValueError: If the tokenizer has not been initialized yet.

    Example:
        cot = "<think>\\nbpm: 120\\ncaption: calm piano\\n...\\n</think>"
        prompt = handler.build_formatted_prompt_with_cot("calm piano", "hello", cot)
    """
    if self.llm_tokenizer is None:
        raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")

    if is_negative_prompt:
        # CFG unconditional branch: empty reasoning block; the caption is
        # swapped for the negative prompt only when one was actually supplied.
        cot_block = "<think>\n</think>"
        shown_caption = (
            negative_prompt
            if self._has_meaningful_negative_prompt(negative_prompt)
            else caption
        )
    else:
        cot_block = cot_text
        shown_caption = caption

    chat = [
        {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSTRUCTION}\n\n"},
        {"role": "user", "content": f"# Caption\n{shown_caption}\n\n# Lyric\n{lyrics}\n"},
        {"role": "assistant", "content": cot_block},
    ]
    rendered = self.llm_tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=False,
    )

    # Generation continues directly after this prompt, so force a trailing newline.
    return rendered if rendered.endswith('\n') else rendered + '\n'
|
|
|
|
|
def build_formatted_prompt_for_understanding(
    self,
    audio_codes: str,
    is_negative_prompt: bool = False,
    negative_prompt: str = "NO USER INPUT"
) -> str:
    """
    Build the chat-formatted prompt for audio understanding from codes.

    This is the reverse of generation: the user message carries audio codes
    and the model is asked (via the understand instruction) to describe them.

    Args:
        audio_codes: Audio code string (e.g., "<|audio_code_123|><|audio_code_456|>...").
        is_negative_prompt: If True, builds the unconditional prompt for CFG.
        negative_prompt: Negative prompt for CFG (only consulted when is_negative_prompt=True).

    Returns:
        Formatted prompt string.

    Raises:
        ValueError: If the tokenizer has not been initialized yet.

    Example:
        codes = "<|audio_code_18953|><|audio_code_13833|>..."
        prompt = handler.build_formatted_prompt_for_understanding(codes)
    """
    if self.llm_tokenizer is None:
        raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")

    # Default (conditional) case: the user turn is the raw code string.
    user_content = audio_codes
    if is_negative_prompt:
        # CFG unconditional turn: the negative prompt when meaningful, else empty.
        user_content = ""
        if negative_prompt and negative_prompt.strip():
            user_content = negative_prompt

    chat = [
        {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_UNDERSTAND_INSTRUCTION}\n\n"},
        {"role": "user", "content": user_content},
    ]
    return self.llm_tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
|
|
|
|
|
def understand_audio_from_codes(
    self,
    audio_codes: str,
    temperature: float = 0.3,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> Tuple[Dict[str, Any], str]:
    """
    Understand audio codes and generate metadata + lyrics.

    This is the reverse of the normal generation flow:
    - Input: Audio codes
    - Output: Metadata (bpm, caption, duration, etc.) + Lyrics

    Note: cfg_scale and negative_prompt are not supported in understand mode.

    Args:
        audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
        temperature: Sampling temperature for generation
        top_k: Top-K sampling (None = disabled)
        top_p: Top-P (nucleus) sampling (None = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        Tuple of (metadata_dict, status_message). On success, metadata_dict
        carries the fields parsed from the model output (bpm, caption,
        duration, keyscale, language, timesignature, ...) plus 'lyrics' when a
        lyrics section followed the </think> tag. On failure, an empty dict
        and an error status.

    Example:
        codes = "<|audio_code_18953|><|audio_code_13833|>..."
        metadata, status = handler.understand_audio_from_codes(codes)
        print(metadata['caption'])  # "A cinematic orchestral piece..."
        print(metadata['lyrics'])   # "[Intro: ...]\\n..."
    """
    # Guard clauses: model must be ready and input must be non-empty.
    if not getattr(self, "llm_initialized", False):
        return {}, "❌ 5Hz LM not initialized. Please initialize it first."

    if not audio_codes or not audio_codes.strip():
        return {}, "❌ No audio codes provided. Please paste audio codes first."

    logger.info(f"Understanding audio codes (length: {len(audio_codes)} chars)")

    formatted_prompt = self.build_formatted_prompt_for_understanding(audio_codes)
    # Fix: this was a bare print() to stdout; route through loguru like
    # every other prompt dump in this class.
    logger.debug(f"formatted_prompt: {formatted_prompt}")

    output_text, status = self.generate_from_formatted_prompt(
        formatted_prompt=formatted_prompt,
        cfg={
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "target_duration": None,
            "user_metadata": None,
            "skip_caption": False,
            "skip_language": False,
            "skip_genres": False,
            "generation_phase": "understand",
            "caption": "",
            "lyrics": "",
        },
        use_constrained_decoding=use_constrained_decoding,
        constrained_decoding_debug=constrained_decoding_debug,
        stop_at_reasoning=False,
    )

    if not output_text:
        return {}, status

    # Structured metadata comes from the CoT section of the output.
    metadata, _ = self.parse_lm_output(output_text)

    # Lyrics (if any) appear after the </think> tag.
    lyrics = self._extract_lyrics_from_output(output_text)
    if lyrics:
        metadata['lyrics'] = lyrics

    logger.info(f"Understanding completed. Generated {len(metadata)} metadata fields")
    if constrained_decoding_debug:
        logger.debug(f"Generated metadata: {list(metadata.keys())}")
        logger.debug(f"Output text preview: {output_text[:200]}...")

    status_msg = f"✅ Understanding completed successfully\nGenerated fields: {', '.join(metadata.keys())}"
    return metadata, status_msg
|
|
|
|
|
def _extract_lyrics_from_output(self, output_text: str) -> str: |
|
|
""" |
|
|
Extract lyrics section from LLM output. |
|
|
|
|
|
The lyrics appear after the </think> tag and typically start with "# Lyric" |
|
|
or directly with lyric content. |
|
|
|
|
|
Args: |
|
|
output_text: Full LLM output text |
|
|
|
|
|
Returns: |
|
|
Extracted lyrics string, or empty string if no lyrics found |
|
|
""" |
|
|
import re |
|
|
|
|
|
|
|
|
think_end_pattern = r'</think>' |
|
|
match = re.search(think_end_pattern, output_text) |
|
|
|
|
|
if not match: |
|
|
|
|
|
return "" |
|
|
|
|
|
|
|
|
after_think = output_text[match.end():].strip() |
|
|
|
|
|
if not after_think: |
|
|
return "" |
|
|
|
|
|
|
|
|
lyric_header_pattern = r'^#\s*Lyri[c|cs]?\s*\n' |
|
|
after_think = re.sub(lyric_header_pattern, '', after_think, flags=re.IGNORECASE) |
|
|
|
|
|
|
|
|
after_think = re.sub(r'<\|im_end\|>\s*$', '', after_think) |
|
|
|
|
|
return after_think.strip() |
|
|
|
|
|
def build_formatted_prompt_for_inspiration(
    self,
    query: str,
    instrumental: bool = False,
    is_negative_prompt: bool = False,
    negative_prompt: str = "NO USER INPUT"
) -> str:
    """
    Build the chat-formatted prompt for inspiration/simple mode.

    Produces a prompt that asks the model to create a complete sample
    (caption, lyrics, metadata) from a natural-language music description.

    Args:
        query: User's natural language music description.
        instrumental: Whether to generate instrumental music (no vocals).
        is_negative_prompt: If True, builds the unconditional prompt for CFG.
        negative_prompt: Negative prompt for CFG (only consulted when is_negative_prompt=True).

    Returns:
        Formatted prompt string.

    Raises:
        ValueError: If the tokenizer has not been initialized yet.

    Example:
        query = "a soft Bengali love song for a quiet evening"
        prompt = handler.build_formatted_prompt_for_inspiration(query, instrumental=False)
    """
    if self.llm_tokenizer is None:
        raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")

    if is_negative_prompt:
        # CFG unconditional turn: the negative prompt when meaningful, else empty.
        user_content = negative_prompt if (negative_prompt and negative_prompt.strip()) else ""
    else:
        # The instrumental flag is serialized as a lowercase literal the model expects.
        flag = "true" if instrumental else "false"
        user_content = f"{query}\n\ninstrumental: {flag}"

    chat = [
        {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSPIRED_INSTRUCTION}\n\n"},
        {"role": "user", "content": user_content},
    ]
    return self.llm_tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
|
|
|
|
|
def create_sample_from_query(
    self,
    query: str,
    instrumental: bool = False,
    vocal_language: Optional[str] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> Tuple[Dict[str, Any], str]:
    """
    Create a complete music sample from a user's natural language query.

    This is the "Simple Mode" / "Inspiration Mode" feature that generates:
    - Metadata (bpm, caption, duration, keyscale, language, timesignature)
    - Lyrics (unless instrumental=True)

    Args:
        query: User's natural language music description
        instrumental: Whether to generate instrumental music (no vocals)
        vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
            If provided and not "unknown", it will be used.
        temperature: Sampling temperature for generation (0.0-2.0)
        top_k: Top-K sampling (None = disabled)
        top_p: Top-P (nucleus) sampling (None = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding
        constrained_decoding_debug: Whether to enable debug logging

    Returns:
        Tuple of (metadata_dict, status_message). On success, metadata_dict
        carries the parsed fields (bpm, caption, duration, keyscale, language,
        timesignature, ...), a 'lyrics' entry, and 'instrumental' echoed back.
        On failure, an empty dict and an error status.

    Example:
        query = "a soft Bengali love song for a quiet evening"
        metadata, status = handler.create_sample_from_query(query, instrumental=False, vocal_language="bn")
        print(metadata['caption'])  # "A gentle romantic acoustic pop ballad..."
        print(metadata['lyrics'])   # "[Intro: ...]\\n..."
    """
    if not getattr(self, "llm_initialized", False):
        return {}, "❌ 5Hz LM not initialized. Please initialize it first."

    # An empty query falls back to the model's "no input" sentinel.
    if not query or not query.strip():
        query = "NO USER INPUT"

    logger.info(f"Creating sample from query: {query[:100]}... (instrumental={instrumental}, vocal_language={vocal_language})")

    formatted_prompt = self.build_formatted_prompt_for_inspiration(
        query=query,
        instrumental=instrumental,
    )
    logger.debug(f"Formatted prompt for inspiration: {formatted_prompt}")

    # A concrete vocal language (anything but empty/"unknown") is injected as
    # a constrained-decoding hint. (Fix: removed a dead `skip_language`
    # local that was assigned and never used.)
    user_metadata = None
    if vocal_language and vocal_language.strip() and vocal_language.strip().lower() != "unknown":
        user_metadata = {"language": vocal_language.strip()}
        logger.info(f"Using user-specified language: {vocal_language.strip()}")

    output_text, status = self.generate_from_formatted_prompt(
        formatted_prompt=formatted_prompt,
        cfg={
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "target_duration": None,
            "user_metadata": user_metadata,
            "skip_caption": False,
            "skip_language": False,
            "skip_genres": False,
            "generation_phase": "understand",
            "caption": "",
            "lyrics": "",
        },
        use_constrained_decoding=use_constrained_decoding,
        constrained_decoding_debug=constrained_decoding_debug,
        stop_at_reasoning=False,
    )

    if not output_text:
        return {}, status

    # Structured metadata comes from the CoT section of the output.
    metadata, _ = self.parse_lm_output(output_text)

    # Lyrics (if any) appear after the </think> tag; instrumental requests
    # without lyrics get the explicit "[Instrumental]" placeholder.
    lyrics = self._extract_lyrics_from_output(output_text)
    if lyrics:
        metadata['lyrics'] = lyrics
    elif instrumental:
        metadata['lyrics'] = "[Instrumental]"

    # Echo the instrumental flag back so callers don't have to re-derive it.
    metadata['instrumental'] = instrumental

    # Fix: the previous messages interpolated the whole metadata dict where a
    # count / field-name list was intended (matches the sibling methods).
    logger.info(f"Sample created successfully. Generated {len(metadata)} metadata fields")
    if constrained_decoding_debug:
        logger.debug(f"Generated metadata: {list(metadata.keys())}")
        logger.debug(f"Output text preview: {output_text[:300]}...")

    status_msg = f"✅ Sample created successfully\nGenerated fields: {', '.join(metadata.keys())}"
    return metadata, status_msg
|
|
|
|
|
def build_formatted_prompt_for_format(
    self,
    caption: str,
    lyrics: str,
    is_negative_prompt: bool = False,
    negative_prompt: str = "NO USER INPUT"
) -> str:
    """
    Build the chat-formatted prompt for format/rewrite mode.

    Asks the model to turn user-provided caption and lyrics into a more
    detailed and specific musical description with metadata.

    Args:
        caption: User's caption/description of the music.
        lyrics: User's lyrics.
        is_negative_prompt: If True, builds the unconditional prompt for CFG.
        negative_prompt: Negative prompt for CFG (only consulted when is_negative_prompt=True).

    Returns:
        Formatted prompt string.

    Raises:
        ValueError: If the tokenizer has not been initialized yet.

    Example:
        caption = "Latin pop, reggaeton, flamenco-pop"
        lyrics = "[Verse 1]\\nTengo un nudo..."
        prompt = handler.build_formatted_prompt_for_format(caption, lyrics)
    """
    if self.llm_tokenizer is None:
        raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")

    if is_negative_prompt:
        # CFG unconditional turn: the negative prompt when meaningful, else empty.
        user_content = negative_prompt if (negative_prompt and negative_prompt.strip()) else ""
    else:
        user_content = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}"

    chat = [
        {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_REWRITE_INSTRUCTION}\n\n"},
        {"role": "user", "content": user_content},
    ]
    return self.llm_tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
|
|
|
|
|
def format_sample_from_input(
    self,
    caption: str,
    lyrics: str,
    user_metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> Tuple[Dict[str, Any], str]:
    """
    Format user-provided caption and lyrics into structured music metadata.

    This is the "Format" feature that takes user input and generates:
    - Enhanced caption with detailed music description
    - Metadata (bpm, duration, keyscale, language, timesignature)
    - Formatted lyrics (preserved from input)

    Note: cfg_scale and negative_prompt are not supported in format mode.

    Args:
        caption: User's caption/description (e.g., "Latin pop, reggaeton")
        lyrics: User's lyrics with structure tags
        user_metadata: Optional dict with user-provided metadata to constrain decoding.
            Supported keys: bpm, duration, keyscale, timesignature, language
        temperature: Sampling temperature for generation (0.0-2.0)
        top_k: Top-K sampling (None = disabled)
        top_p: Top-P (nucleus) sampling (None = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding
        constrained_decoding_debug: Whether to enable debug logging

    Returns:
        Tuple of (metadata_dict, status_message). On success, metadata_dict
        carries the parsed fields (bpm, caption, duration, keyscale, language,
        timesignature, ...) and a 'lyrics' entry (formatted output, or the
        original input when the model emitted none). On failure, an empty
        dict and an error status.

    Example:
        caption = "Latin pop, reggaeton, flamenco-pop"
        lyrics = "[Verse 1]\\nTengo un nudo en la garganta..."
        metadata, status = handler.format_sample_from_input(caption, lyrics)
        print(metadata['caption'])  # "A dramatic and powerful Latin pop track..."
        print(metadata['bpm'])      # 100
    """
    if not getattr(self, "llm_initialized", False):
        return {}, "❌ 5Hz LM not initialized. Please initialize it first."

    # Empty inputs fall back to sentinels the model understands.
    if not caption or not caption.strip():
        caption = "NO USER INPUT"
    if not lyrics or not lyrics.strip():
        lyrics = "[Instrumental]"

    logger.info(f"Formatting sample from input: caption={caption[:50]}..., lyrics length={len(lyrics)}")

    formatted_prompt = self.build_formatted_prompt_for_format(
        caption=caption,
        lyrics=lyrics,
    )
    logger.debug(f"Formatted prompt for format: {formatted_prompt}")

    # Sanitize user-provided constraints: numeric fields must parse as
    # positive ints, string fields must be truthy. Data-driven loops replace
    # the previous copy-pasted per-field blocks (behavior unchanged).
    constrained_metadata: Optional[Dict[str, Any]] = None
    if user_metadata:
        constrained_metadata = {}
        for field in ("bpm", "duration"):
            raw = user_metadata.get(field)
            if raw is None:
                continue
            try:
                value = int(raw)
            except (ValueError, TypeError):
                continue
            if value > 0:
                constrained_metadata[field] = value
        for field in ("keyscale", "timesignature", "language"):
            if user_metadata.get(field):
                constrained_metadata[field] = user_metadata[field]

        # Treat an all-invalid constraint dict the same as no constraints.
        if not constrained_metadata:
            constrained_metadata = None
        else:
            logger.info(f"Using user-provided metadata constraints: {constrained_metadata}")

    output_text, status = self.generate_from_formatted_prompt(
        formatted_prompt=formatted_prompt,
        cfg={
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "target_duration": None,
            "user_metadata": constrained_metadata,
            "skip_caption": False,
            # Skip language generation when the user pinned it via constraints.
            "skip_language": constrained_metadata.get('language') is not None if constrained_metadata else False,
            "skip_genres": False,
            "generation_phase": "understand",
            "caption": "",
            "lyrics": "",
        },
        use_constrained_decoding=use_constrained_decoding,
        constrained_decoding_debug=constrained_decoding_debug,
        stop_at_reasoning=False,
    )

    if not output_text:
        return {}, status

    # Structured metadata comes from the CoT section of the output.
    metadata, _ = self.parse_lm_output(output_text)

    # Prefer the model's formatted lyrics; fall back to the user's input.
    formatted_lyrics = self._extract_lyrics_from_output(output_text)
    metadata['lyrics'] = formatted_lyrics if formatted_lyrics else lyrics

    # Fix: the previous message interpolated the whole metadata dict where a
    # count was intended (matches understand_audio_from_codes).
    logger.info(f"Format completed successfully. Generated {len(metadata)} metadata fields")
    if constrained_decoding_debug:
        logger.debug(f"Generated metadata: {list(metadata.keys())}")
        logger.debug(f"Output text preview: {output_text[:300]}...")

    status_msg = f"✅ Format completed successfully\nGenerated fields: {', '.join(metadata.keys())}"
    return metadata, status_msg
|
|
|
|
|
def generate_from_formatted_prompt( |
|
|
self, |
|
|
formatted_prompt: str, |
|
|
cfg: Optional[Dict[str, Any]] = None, |
|
|
use_constrained_decoding: bool = True, |
|
|
constrained_decoding_debug: bool = False, |
|
|
stop_at_reasoning: bool = False, |
|
|
) -> Tuple[str, str]: |
|
|
""" |
|
|
Generate raw LM text output from a pre-built formatted prompt. |
|
|
|
|
|
Args: |
|
|
formatted_prompt: Prompt that is already formatted by `build_formatted_prompt`. |
|
|
cfg: Optional dict supporting keys: |
|
|
- temperature (float) |
|
|
- cfg_scale (float) |
|
|
- negative_prompt (str) used when cfg_scale > 1 |
|
|
- top_k (int), top_p (float), repetition_penalty (float) |
|
|
- target_duration (float): Target duration in seconds for codes generation |
|
|
- generation_phase (str): "cot" or "codes" for phase-aware CFG |
|
|
use_constrained_decoding: Whether to use FSM-based constrained decoding |
|
|
constrained_decoding_debug: Whether to enable debug logging for constrained decoding |
|
|
stop_at_reasoning: If True, stop generation immediately after </think> tag (no audio codes) |
|
|
|
|
|
Returns: |
|
|
(output_text, status_message) |
|
|
|
|
|
Example: |
|
|
prompt = handler.build_formatted_prompt(caption, lyric) |
|
|
text, status = handler.generate_from_formatted_prompt(prompt, {"temperature": 0.7}) |
|
|
""" |
|
|
if not getattr(self, "llm_initialized", False): |
|
|
return "", "❌ 5Hz LM not initialized. Please initialize it first." |
|
|
if self.llm is None or self.llm_tokenizer is None: |
|
|
return "", "❌ 5Hz LM is missing model or tokenizer." |
|
|
|
|
|
cfg = cfg or {} |
|
|
temperature = cfg.get("temperature", 0.6) |
|
|
cfg_scale = cfg.get("cfg_scale", 1.0) |
|
|
negative_prompt = cfg.get("negative_prompt", "NO USER INPUT") |
|
|
top_k = cfg.get("top_k") |
|
|
top_p = cfg.get("top_p") |
|
|
repetition_penalty = cfg.get("repetition_penalty", 1.0) |
|
|
target_duration = cfg.get("target_duration") |
|
|
user_metadata = cfg.get("user_metadata") |
|
|
skip_caption = cfg.get("skip_caption", False) |
|
|
skip_language = cfg.get("skip_language", False) |
|
|
skip_genres = cfg.get("skip_genres", False) |
|
|
generation_phase = cfg.get("generation_phase", "cot") |
|
|
|
|
|
caption = cfg.get("caption", "") |
|
|
lyrics = cfg.get("lyrics", "") |
|
|
cot_text = cfg.get("cot_text", "") |
|
|
|
|
|
try: |
|
|
if self.llm_backend == "vllm": |
|
|
output_text = self._run_vllm( |
|
|
formatted_prompts=formatted_prompt, |
|
|
temperature=temperature, |
|
|
cfg_scale=cfg_scale, |
|
|
negative_prompt=negative_prompt, |
|
|
top_k=top_k, |
|
|
top_p=top_p, |
|
|
repetition_penalty=repetition_penalty, |
|
|
use_constrained_decoding=use_constrained_decoding, |
|
|
constrained_decoding_debug=constrained_decoding_debug, |
|
|
target_duration=target_duration, |
|
|
user_metadata=user_metadata, |
|
|
stop_at_reasoning=stop_at_reasoning, |
|
|
skip_genres=skip_genres, |
|
|
skip_caption=skip_caption, |
|
|
skip_language=skip_language, |
|
|
generation_phase=generation_phase, |
|
|
caption=caption, |
|
|
lyrics=lyrics, |
|
|
cot_text=cot_text, |
|
|
) |
|
|
return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}" |
|
|
|
|
|
|
|
|
output_text = self._run_pt( |
|
|
formatted_prompts=formatted_prompt, |
|
|
temperature=temperature, |
|
|
cfg_scale=cfg_scale, |
|
|
negative_prompt=negative_prompt, |
|
|
top_k=top_k, |
|
|
top_p=top_p, |
|
|
repetition_penalty=repetition_penalty, |
|
|
use_constrained_decoding=use_constrained_decoding, |
|
|
constrained_decoding_debug=constrained_decoding_debug, |
|
|
target_duration=target_duration, |
|
|
user_metadata=user_metadata, |
|
|
stop_at_reasoning=stop_at_reasoning, |
|
|
skip_genres=skip_genres, |
|
|
skip_caption=skip_caption, |
|
|
skip_language=skip_language, |
|
|
generation_phase=generation_phase, |
|
|
caption=caption, |
|
|
lyrics=lyrics, |
|
|
cot_text=cot_text, |
|
|
) |
|
|
return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}" |
|
|
|
|
|
except Exception as e: |
|
|
return "", f"❌ Error generating from formatted prompt: {e}" |
|
|
|
|
|
    def _generate_with_constrained_decoding(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        max_new_tokens: int,
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
        repetition_penalty: float,
        pad_token_id: int,
        streamer: Optional[BaseStreamer],
        constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None,
    ) -> torch.Tensor:
        """
        Custom generation loop with constrained decoding support (non-CFG).

        This allows us to call update_state() after each token generation,
        which `model.generate()` does not expose.

        Args:
            input_ids: Prompt token ids. Assumed shape (batch, seq_len) — confirm with callers.
            attention_mask: Optional mask matching input_ids; all-ones mask is built if None.
            max_new_tokens: Upper bound on the number of tokens to generate.
            temperature: Sampling temperature forwarded to _sample_tokens.
            top_k / top_p: Optional truncation filters applied to the logits.
            repetition_penalty: Passed to _build_logits_processor.
            pad_token_id: Fallback EOS id when the tokenizer defines none.
            streamer: Optional HF streamer fed one token per step.
            constrained_processor: Optional FSM-based processor that masks logits
                and whose state is advanced after every sampled token.

        Returns:
            The full sequence (prompt + generated tokens), shape (batch, seq_len + new).
        """
        model = self.llm
        device = self.device

        # Work on copies so the caller's tensors are never mutated.
        generated_ids = input_ids.clone()
        if attention_mask is not None:
            attn_mask = attention_mask.clone()
        else:
            attn_mask = torch.ones_like(input_ids)

        model_kwargs = {'attention_mask': attn_mask}

        # KV cache: enabled only if the model's generation_config allows it.
        past_key_values = None
        use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)

        # Fall back to pad_token_id as the stop token if no EOS is defined.
        eos_token_id = self.llm_tokenizer.eos_token_id
        if eos_token_id is None:
            eos_token_id = pad_token_id

        logits_processor = self._build_logits_processor(repetition_penalty)

        with torch.no_grad():
            for step in range(max_new_tokens):
                # NOTE(review): the full generated_ids is passed each step;
                # presumably _forward_pass trims the input when past_key_values
                # is supplied — confirm in its implementation.
                outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)

                # Only the logits of the last position matter for sampling.
                next_token_logits = outputs.logits[:, -1, :]

                # FSM constraint mask runs first so later processors see
                # already-constrained logits.
                if constrained_processor is not None:
                    next_token_logits = constrained_processor(generated_ids, next_token_logits)

                for processor in logits_processor:
                    next_token_logits = processor(generated_ids, next_token_logits)

                # Truncation filters, then temperature sampling.
                next_token_logits = self._apply_top_k_filter(next_token_logits, top_k)
                next_token_logits = self._apply_top_p_filter(next_token_logits, top_p)

                next_tokens = self._sample_tokens(next_token_logits, temperature)

                # Advance the FSM with the token that was actually sampled.
                self._update_constrained_processor_state(constrained_processor, next_tokens)

                # Check EOS *before* appending, but append anyway so the EOS
                # token itself is part of the returned sequence.
                should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)

                next_tokens_unsqueezed = next_tokens.unsqueeze(1)
                generated_ids = torch.cat([generated_ids, next_tokens_unsqueezed], dim=1)
                attn_mask = torch.cat([attn_mask, torch.ones((input_ids.shape[0], 1), device=device, dtype=attn_mask.dtype)], dim=1)
                model_kwargs['attention_mask'] = attn_mask

                # Carry the cache forward for the next step.
                if use_cache and hasattr(outputs, 'past_key_values'):
                    past_key_values = outputs.past_key_values

                if streamer is not None:
                    streamer.put(next_tokens_unsqueezed)

                if should_stop:
                    break

        if streamer is not None:
            streamer.end()

        return generated_ids
|
|
|
|
|
    def _generate_with_cfg_custom(
        self,
        batch_input_ids: torch.Tensor,
        batch_attention_mask: Optional[torch.Tensor],
        max_new_tokens: int,
        temperature: float,
        cfg_scale: float,
        top_k: Optional[int],
        top_p: Optional[float],
        repetition_penalty: float,
        pad_token_id: int,
        streamer: Optional[BaseStreamer],
        constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None,
    ) -> torch.Tensor:
        """
        Custom CFG generation loop that:
        1. Processes both conditional and unconditional sequences in parallel
        2. Applies CFG formula to logits
        3. Samples tokens only for conditional sequences
        4. Applies the same sampled tokens to both conditional and unconditional sequences
        5. Optionally applies constrained decoding via FSM-based logits processor

        Batch format: [cond_input, uncond_input] — the first half of the batch
        dimension holds conditional rows, the second half unconditional rows.
        Assumes both halves are padded to the same sequence length — confirm
        with callers.

        Returns:
            The full batch (both halves, prompt + generated tokens).
        """
        model = self.llm
        device = self.device
        # The batch is [cond..., uncond...]; halve to find each block.
        batch_size = batch_input_ids.shape[0] // 2
        cond_start_idx = 0
        uncond_start_idx = batch_size

        # Work on copies so the caller's tensors are never mutated.
        generated_ids = batch_input_ids.clone()
        if batch_attention_mask is not None:
            attention_mask = batch_attention_mask.clone()
        else:
            attention_mask = torch.ones_like(batch_input_ids)

        # NOTE(review): model_kwargs gets the mask only when the caller passed
        # one, even though an all-ones mask was built above — the non-CFG loop
        # always forwards it; confirm the asymmetry is intentional.
        model_kwargs = {}
        if batch_attention_mask is not None:
            model_kwargs['attention_mask'] = attention_mask

        # KV cache: enabled only if the model's generation_config allows it.
        past_key_values = None
        use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)

        # Fall back to pad_token_id as the stop token if no EOS is defined.
        eos_token_id = self.llm_tokenizer.eos_token_id
        if eos_token_id is None:
            eos_token_id = pad_token_id

        logits_processor = self._build_logits_processor(repetition_penalty)

        with torch.no_grad():
            for step in range(max_new_tokens):
                outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)

                # Last-position logits for every row (cond + uncond).
                next_token_logits = outputs.logits[:, -1, :]

                cond_logits = next_token_logits[cond_start_idx:cond_start_idx+batch_size]
                uncond_logits = next_token_logits[uncond_start_idx:uncond_start_idx+batch_size]

                # Classifier-free guidance: extrapolate from uncond toward cond.
                cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)

                # The FSM processor only ever sees the conditional rows.
                if constrained_processor is not None:
                    current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size]
                    cfg_logits = constrained_processor(current_input_ids, cfg_logits)

                # NOTE(review): same slice as above, recomputed; harmless
                # duplication since generated_ids has not changed in between.
                current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size]
                for processor in logits_processor:
                    cfg_logits = processor(current_input_ids, cfg_logits)

                # Truncation filters, then temperature sampling — once per
                # conditional row; the result is reused for both halves.
                cfg_logits = self._apply_top_k_filter(cfg_logits, top_k)
                cfg_logits = self._apply_top_p_filter(cfg_logits, top_p)

                next_tokens = self._sample_tokens(cfg_logits, temperature)

                # Advance the FSM with the token that was actually sampled.
                self._update_constrained_processor_state(constrained_processor, next_tokens)

                # Check EOS *before* appending, but append anyway so the EOS
                # token itself is part of the returned sequence.
                should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)

                # Append the identical sampled token to BOTH halves so cond
                # and uncond sequences stay aligned.
                next_tokens_unsqueezed = next_tokens.unsqueeze(1)
                generated_ids = torch.cat([generated_ids, next_tokens_unsqueezed.repeat(2, 1)], dim=1)
                attention_mask = torch.cat([attention_mask, torch.ones((batch_size*2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
                model_kwargs['attention_mask'] = attention_mask

                # Carry the cache forward for the next step.
                if use_cache and hasattr(outputs, 'past_key_values'):
                    past_key_values = outputs.past_key_values

                if streamer is not None:
                    streamer.put(next_tokens_unsqueezed)

                if should_stop:
                    break

        if streamer is not None:
            streamer.end()

        return generated_ids
|
|
|
|
|
def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]: |
|
|
""" |
|
|
Parse LM output to extract metadata and audio codes. |
|
|
|
|
|
Expected format: |
|
|
<think> |
|
|
bpm: 73 |
|
|
caption: A calm piano melody |
|
|
duration: 273 |
|
|
genres: Chinese folk |
|
|
keyscale: G major |
|
|
language: en |
|
|
timesignature: 4 |
|
|
</think> |
|
|
|
|
|
<|audio_code_56535|><|audio_code_62918|>... |
|
|
|
|
|
Returns: |
|
|
Tuple of (metadata_dict, audio_codes_string) |
|
|
""" |
|
|
debug_output_text = output_text.split("</think>")[0] |
|
|
logger.debug(f"Debug output text: {debug_output_text}") |
|
|
metadata = {} |
|
|
audio_codes = "" |
|
|
|
|
|
import re |
|
|
|
|
|
|
|
|
code_pattern = r'<\|audio_code_\d+\|>' |
|
|
code_matches = re.findall(code_pattern, output_text) |
|
|
if code_matches: |
|
|
audio_codes = "".join(code_matches) |
|
|
|
|
|
|
|
|
|
|
|
reasoning_patterns = [ |
|
|
r'<think>(.*?)</think>', |
|
|
r'<think>(.*?)</think>', |
|
|
r'<reasoning>(.*?)</reasoning>', |
|
|
] |
|
|
|
|
|
reasoning_text = None |
|
|
for pattern in reasoning_patterns: |
|
|
match = re.search(pattern, output_text, re.DOTALL) |
|
|
if match: |
|
|
reasoning_text = match.group(1).strip() |
|
|
break |
|
|
|
|
|
|
|
|
if not reasoning_text: |
|
|
|
|
|
lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text |
|
|
reasoning_text = lines_before_codes.strip() |
|
|
|
|
|
|
|
|
if reasoning_text: |
|
|
lines = reasoning_text.split('\n') |
|
|
current_key = None |
|
|
current_value_lines = [] |
|
|
|
|
|
def save_current_field(): |
|
|
"""Save the accumulated field value""" |
|
|
nonlocal current_key, current_value_lines |
|
|
if current_key and current_value_lines: |
|
|
|
|
|
value = '\n'.join(current_value_lines) |
|
|
|
|
|
if current_key == 'bpm': |
|
|
try: |
|
|
metadata['bpm'] = int(value.strip()) |
|
|
except: |
|
|
metadata['bpm'] = value.strip() |
|
|
elif current_key == 'caption': |
|
|
|
|
|
metadata['caption'] = MetadataConstrainedLogitsProcessor.postprocess_caption(value) |
|
|
elif current_key == 'duration': |
|
|
try: |
|
|
metadata['duration'] = int(value.strip()) |
|
|
except: |
|
|
metadata['duration'] = value.strip() |
|
|
elif current_key == 'genres': |
|
|
metadata['genres'] = value.strip() |
|
|
elif current_key == 'keyscale': |
|
|
metadata['keyscale'] = value.strip() |
|
|
elif current_key == 'language': |
|
|
metadata['language'] = value.strip() |
|
|
elif current_key == 'timesignature': |
|
|
metadata['timesignature'] = value.strip() |
|
|
|
|
|
current_key = None |
|
|
current_value_lines = [] |
|
|
|
|
|
for line in lines: |
|
|
|
|
|
if line.strip().startswith('<'): |
|
|
continue |
|
|
|
|
|
|
|
|
if line and not line[0].isspace() and ':' in line: |
|
|
|
|
|
save_current_field() |
|
|
|
|
|
|
|
|
parts = line.split(':', 1) |
|
|
if len(parts) == 2: |
|
|
current_key = parts[0].strip().lower() |
|
|
|
|
|
first_value = parts[1] |
|
|
if first_value.strip(): |
|
|
current_value_lines.append(first_value) |
|
|
elif line.startswith(' ') or line.startswith('\t'): |
|
|
|
|
|
if current_key: |
|
|
current_value_lines.append(line) |
|
|
|
|
|
|
|
|
save_current_field() |
|
|
|
|
|
return metadata, audio_codes |
|
|
|
|
|
@contextmanager |
|
|
def _load_model_context(self): |
|
|
""" |
|
|
Context manager to load a model to GPU and offload it back to CPU after use. |
|
|
Only used for PyTorch backend when offload_to_cpu is True. |
|
|
""" |
|
|
if not self.offload_to_cpu: |
|
|
yield |
|
|
return |
|
|
|
|
|
|
|
|
if self.llm_backend == "vllm": |
|
|
yield |
|
|
return |
|
|
|
|
|
model = self.llm |
|
|
if model is None: |
|
|
yield |
|
|
return |
|
|
|
|
|
|
|
|
logger.info(f"Loading LLM to {self.device}") |
|
|
start_time = time.time() |
|
|
if hasattr(model, "to"): |
|
|
model.to(self.device).to(self.dtype) |
|
|
load_time = time.time() - start_time |
|
|
logger.info(f"Loaded LLM to {self.device} in {load_time:.4f}s") |
|
|
|
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
|
|
|
logger.info(f"Offloading LLM to CPU") |
|
|
start_time = time.time() |
|
|
if hasattr(model, "to"): |
|
|
model.to("cpu") |
|
|
torch.cuda.empty_cache() |
|
|
offload_time = time.time() - start_time |
|
|
logger.info(f"Offloaded LLM to CPU in {offload_time:.4f}s") |
|
|
|
|
|
def get_hf_model_for_scoring(self): |
|
|
""" |
|
|
Get HuggingFace model for perplexity scoring. |
|
|
|
|
|
For vllm backend, loads HuggingFace model from disk (weights are cached by transformers). |
|
|
For pt backend, returns the existing model. |
|
|
|
|
|
Returns: |
|
|
HuggingFace model instance |
|
|
""" |
|
|
if self.llm_backend == "pt": |
|
|
|
|
|
return self.llm |
|
|
|
|
|
elif self.llm_backend == "vllm": |
|
|
|
|
|
|
|
|
if self._hf_model_for_scoring is None: |
|
|
logger.info("Loading HuggingFace model for scoring (from checkpoint)") |
|
|
|
|
|
|
|
|
model_runner = self.llm.model_runner |
|
|
model_path = model_runner.config.model |
|
|
|
|
|
|
|
|
|
|
|
import time |
|
|
start_time = time.time() |
|
|
self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained( |
|
|
model_path, |
|
|
trust_remote_code=True, |
|
|
torch_dtype=self.dtype |
|
|
) |
|
|
load_time = time.time() - start_time |
|
|
logger.info(f"HuggingFace model loaded in {load_time:.2f}s") |
|
|
|
|
|
|
|
|
device = next(model_runner.model.parameters()).device |
|
|
self._hf_model_for_scoring = self._hf_model_for_scoring.to(device) |
|
|
self._hf_model_for_scoring.eval() |
|
|
|
|
|
logger.info(f"HuggingFace model for scoring ready on {device}") |
|
|
|
|
|
return self._hf_model_for_scoring |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unknown backend: {self.llm_backend}") |
|
|
|