# CAM++ MLX Model Usage Example
import mlx.core as mx
import numpy as np
from model import CAMPPModel
import json
def load_model(model_path="."):
    """Build a ``CAMPPModel`` from ``config.json`` and load ``weights.npz``.

    Args:
        model_path: Directory containing ``config.json`` and ``weights.npz``.

    Returns:
        The constructed model with its weights loaded, switched to
        inference (non-training) mode.

    Raises:
        FileNotFoundError: if ``config.json`` cannot be opened.
        KeyError: if the config lacks ``input_dim`` or ``embedding_dim``.
    """
    with open(f"{model_path}/config.json", "r", encoding="utf-8") as f:
        config = json.load(f)

    model = CAMPPModel(
        input_dim=config["input_dim"],
        embedding_dim=config["embedding_dim"],
        # Optional config key; falls back to 64 when absent.
        input_channels=config.get("input_channels", 64),
    )

    # mx.load on an .npz returns a dict {name: array}, but
    # nn.Module.load_weights takes a file path or a list of (name, array)
    # pairs; passing the dict directly fails inside tree_unflatten.
    # NOTE(review): assumes CAMPPModel inherits nn.Module.load_weights — TODO confirm.
    weights = mx.load(f"{model_path}/weights.npz")
    model.load_weights(list(weights.items()))

    # MLX modules are created in training mode; switch to inference so
    # layers such as BatchNorm/Dropout (if the model uses them) use their
    # stored statistics / are disabled.
    model.eval()
    return model
def extract_speaker_embedding(model, audio_features):
    """Run ``model`` on a batch of audio features and return its output.

    Args:
        model: A loaded CAM++ model, e.g. as returned by ``load_model``.
        audio_features: Input features, presumably shaped
            (batch, features, time), e.g. a mel-spectrogram — TODO confirm
            that ``features`` must match the model's ``input_dim``.

    Returns:
        The result of ``model(audio_features)``, i.e. the speaker embedding.
    """
    # Force any lazily-loaded parameter arrays to be materialized.
    mx.eval(model.parameters())
    # MLX has no autograd tape and no ``no_grad`` context manager: gradients
    # are only computed inside explicit transforms such as ``mx.grad``, so a
    # plain forward call is already inference-only.
    embedding = model(audio_features)
    return embedding
# Example usage:
# model = load_model()
# features = mx.random.normal((1, 80, 200)) # Example input
# embedding = extract_speaker_embedding(model, features)
# print(f"Speaker embedding shape: {embedding.shape}")