---
license: mit
language: en
tags:
- audio-classification
- carnatic-music
- raga-classification
- indian-classical-music
datasets:
- sarayusapa/carnatic-ragas
metrics:
- accuracy
- f1
pipeline_tag: audio-classification
---

# SAM-Audio: Carnatic Raga Classifier

A CNN + Segment Attention model for classifying Carnatic ragas from audio.

## Model Details

- **Architecture**: SAM-Audio (CNN mel-spectrogram encoder + latent segmentation tokens + masked segment prediction + contrastive learning)
- **Parameters**: 2.6M
- **Training data**: [sarayusapa/carnatic-ragas](https://huggingface.co/datasets/sarayusapa/carnatic-ragas) with 3x pitch-shift augmentation
- **Best validation accuracy**: 99.62%
- **Best epoch**: 17

## Supported Ragas

| ID | Raga |
|----|------|
| 0 | Amritavarshini |
| 1 | Hamsanaadam |
| 2 | Kalyani |
| 3 | Kharaharapriya |
| 4 | Mayamalavagoulai |
| 5 | Sindhubhairavi |
| 6 | Todi |
| 7 | Varali |

## Usage

```python
import json

import torch
import librosa
from safetensors.torch import load_file

# Load model
from train import SAMAudioModel

config = json.load(open("config.json"))
model = SAMAudioModel(
    encoder_config=config["encoder"],
    num_classes=config["num_classes"],
    num_segments=config["num_segments"],
)
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()

# Load audio: 16 kHz mono, truncated to 320,000 samples (20 s)
y, sr = librosa.load("audio.mp3", sr=16000, mono=True)
waveform = torch.from_numpy(y[:320000]).float().unsqueeze(0)

# Predict
with torch.no_grad():
    outputs = model(input_audio=waveform)
    probs = torch.softmax(outputs["raga_logits"], dim=-1)
    pred = probs.argmax(dim=-1).item()

print(f"Predicted: {config['id2label'][str(pred)]} ({probs[0][pred]:.1%})")
```

## Training

- 3x pitch-shift augmentation: each clip yields the original plus one random upward shift (1-4 semitones) and one random downward shift (1-4 semitones)
- The tanpura reference pitch shifts together with the audio, forcing the model to learn relative intervals rather than absolute pitch
- BFloat16 mixed precision on an RTX 4090
- Cosine annealing LR schedule with warmup
- Early stopping with patience=5