import os
import torch
import logging
from omegaconf import OmegaConf
from funcineforge.utils.hinter import get_logger
from funcineforge.models.utils import dtype_map
from funcineforge.datasets import FunCineForgeDS
class AutoFrontend:
    """Front-end that assembles the FunCineForge inference pipeline.

    Loads three sub-models from a checkpoint directory — the language model
    (LM), the flow-matching model (FM) and the vocoder (VOC) — casts each to
    its configured dtype, then wraps them in a single ``AutoModel`` used for
    end-to-end inference over a JSONL dataset.
    """

    def __init__(
        self,
        ckpt_path: str,
        config_path: str,
        output_dir: str,
        device: str = "cuda:0",
    ):
        """Build the full pipeline.

        Args:
            ckpt_path: Root directory containing the
                ``funcineforge_zh_en/{llm,flow,vocoder}`` checkpoint trees.
            config_path: Path to the OmegaConf YAML with model/dataset config.
            output_dir: Directory passed to the inference model for outputs.
            device: Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.
        """
        self.logger = get_logger(log_level=logging.INFO, local_rank=1, world_size=1)
        self.device = device
        self.output_dir = output_dir
        self.lm_model = None
        self.fm_model = None
        self.voc_model = None
        self.model = None
        self.index_ds_class = None
        self.dataset_conf = None
        self.kwargs = OmegaConf.load(config_path)

        if device.startswith("cuda"):
            try:
                device_id = int(device.split(":")[-1])
                torch.cuda.set_device(device_id)
            except (ValueError, IndexError):
                # Malformed device string like "cuda:" or "cuda:x" — fall
                # back to the default GPU rather than crashing.
                self.logger.warning(f"Invalid cuda device string {device}, defaulting to 0")
                torch.cuda.set_device(0)
        else:
            self.logger.info("Running on CPU")

        # Checkpoint layout: <exp_dir>/<model_name>/<ckpt_id>/<weights_file>
        lm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/llm/ds-model.pt.best/mp_rank_00_model_states.pt")
        fm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/flow/ds-model.pt.best/mp_rank_00_model_states.pt")
        voc_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/vocoder/ds-model.pt.best/avg_5_removewn.pt")

        # The three sub-models share the same load-and-cast recipe; the only
        # differences are the checkpoint path and the dtype config key.
        self.lm_model = self._load_submodel("LM", lm_ckpt_path, "llm_dtype")
        self.fm_model = self._load_submodel("FM", fm_ckpt_path, "fm_dtype")
        self.voc_model = self._load_submodel("VOC", voc_ckpt_path, "voc_dtype")

        self.logger.info(f"build inference model {self.kwargs.get('model')}")
        self.kwargs["output_dir"] = output_dir
        self.kwargs["tokenizer"] = None
        from funcineforge import AutoModel
        self.model = AutoModel(
            **self.kwargs,
            lm_model=self.lm_model,
            fm_model=self.fm_model,
            voc_model=self.voc_model,
        )
        self.dataset_conf = self.kwargs.get("dataset_conf")

    def _load_submodel(self, tag: str, ckpt_path: str, dtype_key: str):
        """Load one sub-model checkpoint and cast it to its configured dtype.

        Args:
            tag: Human-readable label used only in log messages (LM/FM/VOC).
            ckpt_path: Full path to the weights file; its grandparent
                directory (``<exp_dir>/<model_name>``) holds the model config.
            dtype_key: Key in the loaded config naming the target dtype;
                defaults to ``"fp32"`` when absent.

        Returns:
            The constructed ``AutoModel`` wrapper, with its inner ``.model``
            moved to the requested dtype.
        """
        from funcineforge import AutoModel

        # Split off <exp_dir>/<model_name> from .../<ckpt_id>/<weights_file>.
        exp_dir, model_name, _ckpt_id, _fname = ckpt_path.rsplit("/", 3)
        self.logger.info(f"init {tag} model from {ckpt_path}")
        submodel = AutoModel(
            model=os.path.join(exp_dir, model_name),
            init_param=ckpt_path,
            output_dir=None,
            device=self.device,
        )
        submodel.model.to(dtype_map[self.kwargs.get(dtype_key, "fp32")])
        return submodel

    def inference(self, jsonl_path: str):
        """Run batch inference over every item in a JSONL manifest.

        Args:
            jsonl_path: Path to the JSONL file describing the inputs; it is
                wrapped in a ``FunCineForgeDS`` using ``dataset_conf`` from
                the loaded config.

        Raises:
            RuntimeError: If ``__init__`` did not produce a model.
        """
        if not self.model:
            raise RuntimeError("Model class not initialized.")
        dataset = FunCineForgeDS(jsonl_path, **self.dataset_conf)
        self.logger.info(f"Starting inference on {len(dataset)} items...")
        self.model.inference(input=dataset, input_len=len(dataset))
        self.logger.info("Inference finished.")