import os
import torch
import logging
from omegaconf import OmegaConf
from funcineforge.utils.hinter import get_logger
from funcineforge.models.utils import dtype_map
from funcineforge.datasets import FunCineForgeDS


class AutoFrontend:
    """Load the FunCineForge LM, flow and vocoder checkpoints and run inference.

    The three sub-models are built from fixed checkpoint locations under
    ``ckpt_path`` and passed to an ``AutoModel`` built from the OmegaConf
    config at ``config_path``.
    """

    def __init__(
        self,
        ckpt_path: str,
        config_path: str,
        output_dir: str,
        device: str = "cuda:0"
    ):
        """Build all models.

        Args:
            ckpt_path: Root directory containing ``funcineforge_zh_en/...`` checkpoints.
            config_path: Path to the OmegaConf config for the inference model.
            output_dir: Stored on the instance and injected into the config as
                ``output_dir`` before the inference model is built.
            device: Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.
        """
        # NOTE(review): local_rank=1 with world_size=1 looks unusual; confirm
        # against get_logger whether rank 0 was intended.
        self.logger = get_logger(log_level=logging.INFO, local_rank=1, world_size=1)
        self.device = device
        self.output_dir = output_dir
        self.lm_model = None
        self.fm_model = None
        self.voc_model = None
        self.model = None
        self.index_ds_class = None  # NOTE(review): not used in the visible code
        self.dataset_conf = None
        self.kwargs = OmegaConf.load(config_path)

        self._select_device(device)

        self.lm_model = self._load_submodel(
            os.path.join(ckpt_path, "funcineforge_zh_en/llm/ds-model.pt.best/mp_rank_00_model_states.pt"),
            "llm_dtype",
            "init LM model",
        )
        self.fm_model = self._load_submodel(
            os.path.join(ckpt_path, "funcineforge_zh_en/flow/ds-model.pt.best/mp_rank_00_model_states.pt"),
            "fm_dtype",
            "build FM model",
        )
        self.voc_model = self._load_submodel(
            os.path.join(ckpt_path, "funcineforge_zh_en/vocoder/ds-model.pt.best/avg_5_removewn.pt"),
            "voc_dtype",
            "build VOC model",
        )

        # Imported lazily, as in the original code.
        from funcineforge import AutoModel

        self.logger.info(f"build inference model {self.kwargs.get('model')}")
        # The config is mutated in place so that its remaining keys can be
        # forwarded to AutoModel together with the pre-built sub-models.
        self.kwargs["output_dir"] = output_dir
        self.kwargs["tokenizer"] = None
        self.model = AutoModel(
            **self.kwargs,
            lm_model=self.lm_model,
            fm_model=self.fm_model,
            voc_model=self.voc_model,
        )
        self.dataset_conf = self.kwargs.get("dataset_conf")

    def _select_device(self, device: str) -> None:
        """Make the CUDA index in ``device`` current; only log for non-CUDA devices.

        A CUDA string whose index cannot be parsed (e.g. plain ``"cuda"``)
        falls back to device 0 with a warning.
        """
        if not device.startswith("cuda"):
            self.logger.info("Running on CPU")
            return
        try:
            device_id = int(device.split(":")[-1])
            torch.cuda.set_device(device_id)
        except (ValueError, IndexError):
            self.logger.warning(f"Invalid cuda device string {device}, defaulting to 0")
            torch.cuda.set_device(0)

    def _load_submodel(self, ckpt_file: str, dtype_key: str, action: str):
        """Build one sub-model from ``ckpt_file`` and cast it to the configured dtype.

        The model directory passed to AutoModel is the directory two levels
        above the checkpoint file (``<model_dir>/<ckpt_id>/<file>``). The dtype
        name is read from ``self.kwargs[dtype_key]``, defaulting to ``"fp32"``.
        """
        self.logger.info(f"{action} from {ckpt_file}")
        # Imported lazily, as in the original code.
        from funcineforge import AutoModel

        model_dir = os.path.dirname(os.path.dirname(ckpt_file))
        submodel = AutoModel(
            model=model_dir,
            init_param=ckpt_file,
            output_dir=None,
            device=self.device,
        )
        submodel.model.to(dtype_map[self.kwargs.get(dtype_key, "fp32")])
        return submodel

    def inference(self, jsonl_path: str):
        """Run the inference model over every item listed in ``jsonl_path``.

        The dataset is built with ``FunCineForgeDS`` using the config's
        ``dataset_conf`` section (an empty mapping if the section is absent).

        Raises:
            RuntimeError: if the inference model was not built.
        """
        if self.model is None:
            raise RuntimeError("Model class not initialized.")
        # Guard against configs without a dataset_conf section; **None would
        # raise a TypeError.
        dataset_conf = self.dataset_conf if self.dataset_conf is not None else {}
        dataset = FunCineForgeDS(jsonl_path, **dataset_conf)
        n_items = len(dataset)
        self.logger.info(f"Starting inference on {n_items} items...")
        self.model.inference(input=dataset, input_len=n_items)
        self.logger.info("Inference finished.")