Spaces:
Running on Zero
Running on Zero
| import os | |
| import torch | |
| import logging | |
| from omegaconf import OmegaConf | |
| from funcineforge.utils.hinter import get_logger | |
| from funcineforge.models.utils import dtype_map | |
| from funcineforge.datasets import FunCineForgeDS | |
class AutoFrontend:
    """End-to-end inference frontend for FunCineForge TTS.

    Loads three sub-models from a checkpoint root — the language model (LM),
    the flow-matching model (FM) and the vocoder (VOC) — casts each to the
    dtype configured in the YAML config, then composes them into a single
    inference ``AutoModel``.

    Args:
        ckpt_path: Root directory containing the ``funcineforge_zh_en/...``
            checkpoint subtree (see ``_LM_CKPT`` etc. below).
        config_path: Path to an OmegaConf YAML with at least ``model`` and
            ``dataset_conf`` keys; may also carry ``llm_dtype`` / ``fm_dtype``
            / ``voc_dtype`` (default "fp32").
        output_dir: Directory passed through to the inference model.
        device: Torch device string, e.g. "cuda:0" or "cpu".
    """

    # Checkpoint locations relative to ``ckpt_path``. Each path has the
    # layout <exp_dir>/<model_name>/<ckpt_id>/<file>, which _build_submodel
    # splits apart with rsplit("/", 3).
    _LM_CKPT = "funcineforge_zh_en/llm/ds-model.pt.best/mp_rank_00_model_states.pt"
    _FM_CKPT = "funcineforge_zh_en/flow/ds-model.pt.best/mp_rank_00_model_states.pt"
    _VOC_CKPT = "funcineforge_zh_en/vocoder/ds-model.pt.best/avg_5_removewn.pt"

    def __init__(
        self,
        ckpt_path: str,
        config_path: str,
        output_dir: str,
        device: str = "cuda:0",
    ):
        # NOTE(review): local_rank=1 with world_size=1 looks odd (rank is
        # normally 0) — preserved from the original; confirm intent.
        self.logger = get_logger(log_level=logging.INFO, local_rank=1, world_size=1)
        self.device = device
        self.output_dir = output_dir
        self.lm_model = None
        self.fm_model = None
        self.voc_model = None
        self.model = None
        self.index_ds_class = None
        self.dataset_conf = None
        self.kwargs = OmegaConf.load(config_path)

        # Pin the CUDA device; fall back to device 0 on a malformed string
        # (e.g. "cuda" with no index, or a non-numeric index).
        if device.startswith("cuda"):
            try:
                device_id = int(device.split(":")[-1])
                torch.cuda.set_device(device_id)
            except (ValueError, IndexError):
                self.logger.warning(f"Invalid cuda device string {device}, defaulting to 0")
                torch.cuda.set_device(0)
        else:
            self.logger.info("Running on CPU")

        # Build the three sub-models (same load pattern, different
        # checkpoints and dtype config keys).
        self.lm_model = self._build_submodel(
            "init LM", os.path.join(ckpt_path, self._LM_CKPT), "llm_dtype"
        )
        self.fm_model = self._build_submodel(
            "build FM", os.path.join(ckpt_path, self._FM_CKPT), "fm_dtype"
        )
        self.voc_model = self._build_submodel(
            "build VOC", os.path.join(ckpt_path, self._VOC_CKPT), "voc_dtype"
        )

        # Compose the end-to-end inference model from the three sub-models.
        from funcineforge import AutoModel

        self.logger.info(f"build inference model {self.kwargs.get('model')}")
        self.kwargs["output_dir"] = output_dir
        self.kwargs["tokenizer"] = None
        self.model = AutoModel(
            **self.kwargs,
            lm_model=self.lm_model,
            fm_model=self.fm_model,
            voc_model=self.voc_model,
        )
        self.dataset_conf = self.kwargs.get("dataset_conf")

    def _build_submodel(self, log_prefix: str, ckpt_path: str, dtype_key: str):
        """Load one AutoModel sub-model from a checkpoint file and cast it.

        Args:
            log_prefix: Verb + tag for the log line (e.g. "init LM").
            ckpt_path: Full path to the checkpoint file; its last three
                path components are <model_name>/<ckpt_id>/<file>.
            dtype_key: Config key selecting the dtype (default "fp32").

        Returns:
            The loaded ``AutoModel`` wrapper, with its inner ``.model``
            moved to the configured dtype.
        """
        # Local import preserved from the original (presumably avoids a
        # circular import — confirm).
        from funcineforge import AutoModel

        exp_dir, model_name, _ckpt_id, _ = ckpt_path.rsplit("/", 3)
        self.logger.info(f"{log_prefix} model from {ckpt_path}")
        wrapper = AutoModel(
            model=os.path.join(exp_dir, model_name),
            init_param=ckpt_path,
            output_dir=None,
            device=self.device,
        )
        wrapper.model.to(dtype_map[self.kwargs.get(dtype_key, "fp32")])
        return wrapper

    def inference(self, jsonl_path: str):
        """Run inference over every item listed in a JSONL manifest.

        Args:
            jsonl_path: Path to the JSONL input list consumed by
                ``FunCineForgeDS``.

        Raises:
            RuntimeError: If the composed model was never initialized.
        """
        if not self.model:
            raise RuntimeError("Model class not initialized.")
        dataset = FunCineForgeDS(jsonl_path, **self.dataset_conf)
        self.logger.info(f"Starting inference on {len(dataset)} items...")
        self.model.inference(input=dataset, input_len=len(dataset))
        self.logger.info("Inference finished.")