# Source: uploaded by ray-006 in commit fc605f9 ("Upload 43 files", verified).
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Optional
import torch
from audiobox_aesthetics.infer import AesPredictor
# Short axis codes that appear as keys in each AesPredictor result row,
# mapped to the descriptive column names reported by Aesthetic.__call__.
# Insertion order (CE, CU, PC, PQ) is the order of the output dict's keys.
COLUMN_MAP = dict(
    CE="ContentEnjoyment",
    CU="ContentUsefulness",
    PC="ProductionComplexity",
    PQ="ProductionQuality",
)
class Aesthetic(torch.nn.Module):
    """Score in-memory waveforms with ``audiobox_aesthetics``' ``AesPredictor``.

    Each call returns one list of scores per aesthetic axis, keyed by the
    long column names in ``COLUMN_MAP``.
    """

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        device: Optional[torch.device] = None,
    ):
        """Build the underlying predictor.

        Args:
            checkpoint: Forwarded to ``AesPredictor`` as ``checkpoint_pth``.
                ``None`` is passed through unchanged; how it is resolved is
                up to ``AesPredictor``.
            device: Stored on ``self.device``. Defaults to CUDA when
                ``torch.cuda.is_available()`` is true, otherwise CPU.
        """
        super().__init__()
        # Batch items handed to the predictor carry the waveform under "wav".
        self.model = AesPredictor(
            checkpoint_pth=checkpoint,
            data_col="wav",
        )
        # NOTE(review): self.device is recorded but never passed to
        # AesPredictor, so it does not control where inference runs from
        # here. Confirm whether AesPredictor selects its own device.
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    def __call__(
        self,
        target_wavs: list[torch.Tensor],
        target_wavs_sample_rate: int = 48_000,
        **kwargs,
    ) -> dict[str, list[float]]:
        """Score a batch of waveforms.

        Args:
            target_wavs: Waveform tensors to score. 1-D tensors get a leading
                axis (``wav[None]``). Tensors of any other rank are passed
                through unchanged (presumably ``(channels, samples)``;
                confirm against AesPredictor).
            target_wavs_sample_rate: Sample rate applied to every waveform in
                ``target_wavs``.
            **kwargs: Accepted and ignored.

        Returns:
            A dict mapping each long name in ``COLUMN_MAP`` to a list holding
            that axis' value from every predictor result row (presumably in
            input order). If ``target_wavs`` is empty, every list is empty
            and the predictor is not called.

        Raises:
            KeyError: If a predictor result row lacks one of the short axis
                codes in ``COLUMN_MAP``.
        """
        batch = [
            {
                # Give 1-D waveforms a leading axis; other ranks pass through.
                "wav": wav[None] if wav.ndim == 1 else wav,
                "sample_rate": target_wavs_sample_rate,
            }
            for wav in target_wavs
        ]
        if not batch:
            # Nothing to score: skip the model call and return empty columns.
            return {long_name: [] for long_name in COLUMN_MAP.values()}

        result = self.model.forward(batch)
        # Pivot per-item rows of short-coded scores into per-axis columns.
        return {
            long_name: [row[short_name] for row in result]
            for short_name, long_name in COLUMN_MAP.items()
        }