# Fun-CineForge-Demo / funcineforge /auto /auto_frontend.py
# xuan3986's picture
# Upload 111 files
# 03022ee verified
import os
import torch
import logging
from omegaconf import OmegaConf
from funcineforge.utils.hinter import get_logger
from funcineforge.models.utils import dtype_map
from funcineforge.datasets import FunCineForgeDS
class AutoFrontend:
    """Load the LM, FM and VOC sub-models and drive the combined inference model.

    On construction the class reads an OmegaConf config, selects the CUDA
    device (when a ``cuda`` device string is given), builds three ``AutoModel``
    instances from fixed checkpoint locations under ``ckpt_path``, and finally
    builds the top-level inference ``AutoModel`` that receives those three
    sub-models. ``inference`` then runs that model over a ``FunCineForgeDS``
    dataset.

    Attributes:
        logger: Logger returned by ``get_logger``.
        device: Device string passed to every ``AutoModel``.
        output_dir: Output directory forwarded to the inference model.
        lm_model, fm_model, voc_model: The three sub-models.
        model: The combined inference model.
        index_ds_class: Initialised to ``None``; not used inside this class.
        dataset_conf: ``dataset_conf`` entry of the loaded config (may be
            ``None`` when the config does not define it).
        kwargs: The loaded config, extended with ``output_dir`` and
            ``tokenizer`` before the inference model is built.
    """

    def __init__(
        self,
        ckpt_path: str,
        config_path: str,
        output_dir: str,
        device: str = "cuda:0",
    ):
        """Build all models.

        Args:
            ckpt_path: Root directory that contains the
                ``funcineforge_zh_en/{llm,flow,vocoder}`` checkpoint trees.
            config_path: Path of the config file loaded with ``OmegaConf.load``.
            output_dir: Output directory stored in the config passed to the
                inference model.
            device: Device string, e.g. ``"cuda:0"``. Any string that does not
                start with ``"cuda"`` takes the CPU branch.
        """
        # NOTE(review): local_rank=1 together with world_size=1 looks unusual;
        # kept as in the original - confirm against get_logger's contract.
        self.logger = get_logger(log_level=logging.INFO, local_rank=1, world_size=1)
        self.device = device
        self.output_dir = output_dir
        self.lm_model = None
        self.fm_model = None
        self.voc_model = None
        self.model = None
        self.index_ds_class = None
        self.dataset_conf = None
        self.kwargs = OmegaConf.load(config_path)

        self._configure_device(device)

        # The relative parts are kept as single "/"-separated literals because
        # _model_dir_from_ckpt splits on "/" to recover the model directory.
        lm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/llm/ds-model.pt.best/mp_rank_00_model_states.pt")
        fm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/flow/ds-model.pt.best/mp_rank_00_model_states.pt")
        voc_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/vocoder/ds-model.pt.best/avg_5_removewn.pt")

        # Imported at call time, as in the original (presumably to avoid a
        # circular import with the funcineforge package - confirm).
        from funcineforge import AutoModel

        # Log prefixes (including the "form" spelling) are reproduced exactly
        # so existing log searches keep matching.
        self.lm_model = self._load_sub_model(
            AutoModel, lm_ckpt_path, "llm_dtype", "init LM model form"
        )
        self.fm_model = self._load_sub_model(
            AutoModel, fm_ckpt_path, "fm_dtype", "build FM model form"
        )
        self.voc_model = self._load_sub_model(
            AutoModel, voc_ckpt_path, "voc_dtype", "build VOC model form"
        )

        self.logger.info(f"build inference model {self.kwargs.get('model')}")
        self.kwargs["output_dir"] = output_dir
        self.kwargs["tokenizer"] = None
        self.model = AutoModel(
            **self.kwargs,
            lm_model=self.lm_model,
            fm_model=self.fm_model,
            voc_model=self.voc_model,
        )
        self.dataset_conf = self.kwargs.get("dataset_conf")

    def _configure_device(self, device: str) -> None:
        """Select the CUDA device named by ``device``, or log the CPU branch."""
        if not device.startswith("cuda"):
            # NOTE(review): every non-"cuda" string (not only "cpu") lands here.
            self.logger.info("Running on CPU")
            return
        try:
            # Only the int() conversion can fail here; a bare "cuda" (no
            # ":<index>") also ends up in the fallback below.
            device_id = int(device.split(":")[-1])
        except ValueError:
            self.logger.warning(f"Invalid cuda device string {device}, defaulting to 0")
            device_id = 0
        torch.cuda.set_device(device_id)

    @staticmethod
    def _model_dir_from_ckpt(ckpt_file: str) -> str:
        """Return ``<exp_dir>/<model_name>`` for ``<exp_dir>/<model_name>/<ckpt_id>/<file>``."""
        exp_dir, model_name, _ckpt_id, _file_name = ckpt_file.rsplit("/", 3)
        return os.path.join(exp_dir, model_name)

    def _load_sub_model(self, auto_model_cls, ckpt_file: str, dtype_key: str, log_prefix: str):
        """Build one sub-model from ``ckpt_file`` and cast it to its configured dtype.

        Args:
            auto_model_cls: The ``AutoModel`` class used to build the model.
            ckpt_file: Checkpoint file passed as ``init_param``.
            dtype_key: Config key holding the dtype name; ``"fp32"`` is used
                when the key is absent.
            log_prefix: Text logged in front of the checkpoint path.

        Returns:
            The constructed ``AutoModel`` instance.
        """
        self.logger.info(f"{log_prefix} {ckpt_file}")
        sub_model = auto_model_cls(
            model=self._model_dir_from_ckpt(ckpt_file),
            init_param=ckpt_file,
            output_dir=None,
            device=self.device,
        )
        # The return value of .to() is ignored, as in the original; this
        # assumes .to() converts the wrapped model in place - TODO confirm.
        sub_model.model.to(dtype_map[self.kwargs.get(dtype_key, "fp32")])
        return sub_model

    def inference(self, jsonl_path: str):
        """Run the inference model over the dataset built from ``jsonl_path``.

        Args:
            jsonl_path: Path handed to ``FunCineForgeDS`` as its first argument.

        Raises:
            RuntimeError: If the inference model has not been built.
        """
        # Identity check: a valid model object must not be rejected merely
        # because it evaluates as falsy.
        if self.model is None:
            raise RuntimeError("Model class not initialized.")
        # A config without "dataset_conf" leaves this attribute as None;
        # unpacking None would raise an opaque TypeError, so pass no extras.
        dataset_conf = self.dataset_conf if self.dataset_conf is not None else {}
        dataset = FunCineForgeDS(jsonl_path, **dataset_conf)
        num_items = len(dataset)
        self.logger.info(f"Starting inference on {num_items} items...")
        self.model.inference(input=dataset, input_len=num_items)
        self.logger.info("Inference finished.")