kazega0/KazEGA
Viewer • Updated • 96.6k • 62
Multi-task HuBERT model for Kazakh speech, jointly predicting emotion, gender, and age from raw audio. Trained on the KazEGA corpus.
`MultiTaskBackbone` — a frozen or fine-tuned `hubert_base` encoder
(`facebook/hubert-base-ls960`) with per-task learnable softmax weights over the hidden
layers; the weighted layer sum is mean-pooled over time and fed to three independent MLP classification heads.
facebook/hubert-base-ls960

| task | classes | labels |
|---|---|---|
| emotion | 7 | angry, disgusted, fearful, happy, neutral, sad, surprised |
| gender | 2 | F, M |
| age | 4 | adult, child, senior, young |
pip install torch transformers huggingface_hub librosa numpy
import json
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import HubertModel, Wav2Vec2FeatureExtractor
# Hub repository that hosts the fine-tuned checkpoint and the label encoders.
REPO_ID = "kazega0/KazEGA-Hubert"
# Base encoder; also used to build the feature extractor.
PRETRAINED = "facebook/hubert-base-ls960"

# Audio is resampled to this rate (Hz) when it is loaded.
SAMPLE_RATE = 16_000
# Clips are truncated / zero-padded to this many samples (10 s at 16 kHz).
MAX_LENGTH = 10 * SAMPLE_RATE
# Path of the clip to classify.
AUDIO_PATH = "sample.wav"

# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class MultiTaskHubert(nn.Module):
    """HuBERT encoder with three classification heads: emotion, gender, age.

    Each task owns a learnable weight vector over the encoder's hidden-state
    outputs. The weights are softmax-normalised, used to form a weighted sum
    of the layers, averaged over the time axis and passed to that task's MLP.

    Args:
        num_emotions: Number of emotion classes.
        num_genders: Number of gender classes.
        num_ages: Number of age classes.
        pretrained: Name or path of the HuBERT checkpoint to start from.
    """

    def __init__(self, num_emotions, num_genders, num_ages, pretrained=PRETRAINED):
        super().__init__()
        self.hubert = HubertModel.from_pretrained(pretrained, output_hidden_states=True)
        hidden_size = self.hubert.config.hidden_size
        # One extra entry on top of the transformer layers: `hidden_states`
        # also contains the output that feeds the first layer.
        num_layers = self.hubert.config.num_hidden_layers + 1

        # Equal initial logits => uniform layer mixture after the softmax.
        self.emotion_weights = nn.Parameter(torch.ones(num_layers))
        self.gender_weights = nn.Parameter(torch.ones(num_layers))
        self.age_weights = nn.Parameter(torch.ones(num_layers))

        # Attribute names and Sequential layout are part of the checkpoint's
        # state_dict keys; keep them stable.
        self.emotion_head = self._head(hidden_size, num_emotions, dropout=0.2)
        self.gender_head = self._head(hidden_size, num_genders, dropout=0.1)
        self.age_head = self._head(hidden_size, num_ages, dropout=0.1)

    @staticmethod
    def _head(hidden_size, num_classes, dropout):
        """Build one Linear -> ReLU -> Dropout -> Linear classification head."""
        return nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, input_values, input_length):
        """Return ``(emotion_logits, gender_logits, age_logits)``.

        Args:
            input_values: Waveform batch passed straight to the HuBERT encoder.
            input_length: Number of real (un-padded) samples, as a single
                scalar. The same length is applied to every item in the
                batch, so frames produced by trailing padding are excluded
                from the time average.

        Returns:
            A 3-tuple of logit tensors, one per task, each produced from the
            time-averaged, layer-weighted encoder features.
        """
        # Stack the per-layer outputs on a new leading axis:
        # (layers, batch, frames, hidden).
        hidden = torch.stack(self.hubert(input_values).hidden_states, dim=0)

        # Convert the sample count into a count of encoder frames.
        feat_len = int(
            self.hubert._get_feat_extract_output_lengths(torch.as_tensor(input_length))
        )
        # Clips shorter than the conv front end's receptive field give a
        # length <= 0. Unclamped, a negative value would silently drop
        # trailing frames and zero would make the mean below NaN, so always
        # keep at least one frame and never more than exist.
        feat_len = max(1, min(feat_len, hidden.size(2)))
        hidden = hidden[:, :, :feat_len, :]

        def pool(layer_weights):
            # Softmax over layers, weighted layer sum, then mean over frames.
            w = torch.softmax(layer_weights, dim=0)
            return (w.view(-1, 1, 1, 1) * hidden).sum(dim=0).mean(dim=1)

        return (
            self.emotion_head(pool(self.emotion_weights)),
            self.gender_head(pool(self.gender_weights)),
            self.age_head(pool(self.age_weights)),
        )
# Download the fine-tuned checkpoint and rebuild the model from it.
# NOTE(security): weights_only=False makes torch.load unpickle arbitrary
# Python objects, which can execute code embedded in the file. Only use this
# with a repository you trust; prefer weights_only=True if the checkpoint
# loads that way.
ckpt = torch.load(hf_hub_download(REPO_ID, "model.pt"), map_location="cpu", weights_only=False)

# Label encoders map task -> {label name: class index}; invert them so a
# predicted index can be turned back into its label name.
# Use a context manager so the file handle is closed, and read as UTF-8
# rather than the platform's locale encoding.
with open(hf_hub_download(REPO_ID, "label_encoders.json"), encoding="utf-8") as fh:
    encoders = json.load(fh)
id2label = {task: {idx: name for name, idx in m.items()} for task, m in encoders.items()}

# Head sizes are stored in the checkpoint next to the weights.
model = MultiTaskHubert(ckpt["num_emotions"], ckpt["num_genders"], ckpt["num_ages"])
model.load_state_dict(ckpt["model_state_dict"])
model.to(DEVICE).eval()

# The feature extractor comes from the base checkpoint, not the fine-tuned repo.
processor = Wav2Vec2FeatureExtractor.from_pretrained(PRETRAINED)
# Load the clip as mono at the model's sample rate and cap its length.
audio, _ = librosa.load(AUDIO_PATH, sr=SAMPLE_RATE, mono=True)
audio = audio[:MAX_LENGTH]
raw_length = len(audio)
if raw_length == 0:
    # An empty clip would be padded to pure silence and still yield
    # confident-looking labels; fail loudly instead.
    raise ValueError(f"{AUDIO_PATH!r} contains no audio samples")

# Zero-pad on the right to a fixed size; raw_length records how much of the
# buffer is real audio so the model can ignore the padded tail when pooling.
audio = np.pad(audio, (0, MAX_LENGTH - raw_length))
input_values = processor(
    audio, sampling_rate=SAMPLE_RATE, return_tensors="pt"
).input_values.to(DEVICE)

with torch.no_grad():
    logits = model(input_values, raw_length)

# The model returns logits in (emotion, gender, age) order; report the top
# class of each task with its softmax probability.
for task, task_logits in zip(("emotion", "gender", "age"), logits):
    probs = F.softmax(task_logits[0], dim=0)
    idx = int(probs.argmax())
    print(f"{task:8s} {id2label[task][idx]:10s} ({probs[idx]:.1%})")
Base model
facebook/hubert-base-ls960