# ChuT Compressed Model: Knowledge Distillation (Kiran)

A compressed version of the ChuT Late Fusion tuberculosis (TB) detection model, obtained by distilling the original (teacher) model into a smaller student model.

- Author: Nannyombi Shakiran B.
- Technique: Knowledge Distillation (Hinton et al., 2015)
- Original model: TB42project/Late_fusion_model

## Results

| Metric | Teacher | Student |
|---|---|---|
| AUROC | 0.9849 | 0.9195 |
| Compression | | 17.02× smaller |
| AUROC retention | | 93.4% |
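AUROC retention is the student's AUROC expressed as a percentage of the teacher's. The short check below reproduces the reported figure from the two AUROC values in the table; the helper name `auroc_retention` is illustrative and not part of the repo.

```python
def auroc_retention(teacher_auroc: float, student_auroc: float) -> float:
    """Return the student's AUROC as a percentage of the teacher's AUROC."""
    return 100.0 * student_auroc / teacher_auroc

# AUROC values taken from the Results table
retention = auroc_retention(0.9849, 0.9195)
print(f"AUROC retention: {retention:.1f}%")
```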

## How to Use

```python
import joblib
import librosa
import numpy as np
import pandas as pd
import torch

SR = 22050                   # sample rate expected by the audio model
DURATION_S = 5               # clips are padded or trimmed to exactly 5 s
TARGET_LEN = SR * DURATION_S

# Load the student bundle
bundle = joblib.load('student_fusion_model.pkl')

# --- Audio prediction ---
# 1. Load and preprocess audio: pad/trim to 5 s, compute a log-mel
#    spectrogram, then min-max scale it to [0, 1]
audio, _ = librosa.load('cough.wav', sr=SR, duration=DURATION_S)
audio = np.pad(audio, (0, max(0, TARGET_LEN - len(audio))))[:TARGET_LEN]
mel = librosa.feature.melspectrogram(y=audio, sr=SR, n_mels=128, n_fft=2048, hop_length=512)
log_mel = librosa.power_to_db(mel, ref=np.max)
log_mel = (log_mel - log_mel.min()) / (log_mel.max() - log_mel.min() + 1e-8)
mel_tensor = torch.tensor(log_mel, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # (1, 1, 128, T)

# 2. Run the student audio model. StudentAudioCNN is defined in the repo;
#    instantiate it and load its weights before uncommenting these lines.
# student_audio.eval()
# with torch.no_grad():
#     audio_prob = torch.sigmoid(student_audio(mel_tensor)).item()

# --- Clinical prediction ---
# Column order must match the order the scaler and model were fitted with.
features = ['sex', 'age', 'height', 'weight', 'reported_cough_dur', 'hemoptysis',
            'weight_loss', 'fever', 'night_sweats', 'smoke_lweek', 'heart_rate',
            'temperature', 'tb_prior', 'tb_prior_Pul', 'tb_prior_Extrapul', 'tb_prior_Unknown']
clinical_data = pd.DataFrame([[1, 25, 170, 65, 14, 0, 1, 1, 0, 0, 80, 37.2, 0, 0, 0, 0]],
                             columns=features)
clinical_scaled = bundle['clinical_scaler'].transform(clinical_data)
clinical_prob = bundle['clinical_model'].predict_proba(clinical_scaled)[0, 1]  # P(TB positive)

# --- Fusion (requires audio_prob from step 2) ---
# final_prob = 0.28 * audio_prob + 0.72 * clinical_prob
# prediction = 'TB Positive' if final_prob > 0.5 else 'TB Negative'
```
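The shape comment `(1, 1, 128, T)` leaves the number of time frames implicit. Because every clip is padded or trimmed to exactly 5 s, T is fixed. The sketch below computes it without loading any audio, assuming librosa's default centered STFT (`center=True`) and an even `n_fft` such as 2048; `expected_mel_frames` is an illustrative helper, not part of the repo.

```python
def expected_mel_frames(sr: int = 22050, duration_s: int = 5, hop_length: int = 512) -> int:
    """Number of STFT frames for a fixed-length clip, assuming center=True.

    With centering, n_fft // 2 samples are padded on each side, so for an
    even n_fft the frame count reduces to 1 + n_samples // hop_length.
    """
    n_samples = sr * duration_s
    return 1 + n_samples // hop_length

T = expected_mel_frames()
print((1, 1, 128, T))  # expected shape of mel_tensor under these settings
```

A model trained on one fixed T can break or silently degrade if the clip length, sample rate, or hop length is changed at inference time, so these constants should stay in sync with training.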

## Method
Offline Knowledge Distillation (Hinton et al., 2015)
- Temperature T = 4.0
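- Alpha α = 0.7 (weight of the distillation term)
- Loss = α × T² × KL(teacher_soft ∥ student_soft) + (1−α) × BCE(student, hard_label), where teacher_soft and student_soft are the temperature-softened output probabilities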
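The card does not spell out how the KL term is computed for a binary classifier. The NumPy sketch below assumes each model emits a single logit per sample (as the `torch.sigmoid(...)` call in the usage code suggests), that the softened probabilities are `sigmoid(logit / T)`, and that the KL divergence is taken between the resulting Bernoulli distributions. All function names are illustrative, not taken from the repo.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bernoulli_kl(p, q, eps=1e-7):
    """KL(p || q) between Bernoulli distributions with success probabilities p and q."""
    p = np.clip(p, eps, 1 - eps)
    q = np.clip(q, eps, 1 - eps)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

def bce(p, y, eps=1e-7):
    """Binary cross-entropy of predicted probabilities p against labels y."""
    p = np.clip(p, eps, 1 - eps)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """alpha * T^2 * KL(teacher_soft || student_soft) + (1 - alpha) * BCE(student, label)."""
    s = np.asarray(student_logits, dtype=float)
    t = np.asarray(teacher_logits, dtype=float)
    y = np.asarray(labels, dtype=float)
    teacher_soft = sigmoid(t / T)            # temperature-softened teacher probability
    student_soft = sigmoid(s / T)            # temperature-softened student probability
    soft_term = bernoulli_kl(teacher_soft, student_soft).mean()
    hard_term = bce(sigmoid(s), y).mean()    # hard-label loss at T = 1
    return alpha * T**2 * soft_term + (1 - alpha) * hard_term

loss = kd_loss([2.0, -1.5], [3.0, -2.0], [1, 0])
print(f"{loss:.4f}")
```

The T² factor follows Hinton et al.: softening by T shrinks the soft-target gradients by roughly 1/T², so multiplying by T² keeps the two terms on a comparable scale. In PyTorch training code the hard-label term would usually be computed with `torch.nn.functional.binary_cross_entropy_with_logits` for numerical stability.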

## Student Architecture
- Audio: LightTBCNN (2 ResBlocks, 64 filters)
- Clinical: MLP (32 → 16)
- Fusion: Late fusion (28% audio + 72% clinical)
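In late fusion the audio and clinical branches are scored independently and only their output probabilities are combined. A minimal sketch of that final step, using the 28%/72% weights above and the 0.5 decision threshold from the usage example; the function and constant names are illustrative, not taken from the repo.

```python
AUDIO_WEIGHT = 0.28     # audio share of the fused score
CLINICAL_WEIGHT = 0.72  # clinical share of the fused score
THRESHOLD = 0.5         # decision threshold from the usage example

def late_fusion(audio_prob: float, clinical_prob: float) -> tuple[float, str]:
    """Weighted average of the two branch probabilities, followed by thresholding."""
    for p in (audio_prob, clinical_prob):
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"probability out of range: {p}")
    final_prob = AUDIO_WEIGHT * audio_prob + CLINICAL_WEIGHT * clinical_prob
    label = 'TB Positive' if final_prob > THRESHOLD else 'TB Negative'
    return final_prob, label

print(late_fusion(0.40, 0.80))
```

Because the two weights sum to 1, the fused score stays in [0, 1] and can be thresholded like a probability.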

⚠️ Research screening tool only — not a clinical diagnostic device.