jacob1576 committed on
Commit
cffc5b3
·
1 Parent(s): 28e5197

Updated pip requirements and added code to load model from HF hub

Browse files
Files changed (3) hide show
  1. app.py +10 -4
  2. config.yaml +48 -0
  3. requirements.txt +5 -0
app.py CHANGED
@@ -17,7 +17,8 @@ import matplotlib.pyplot as plt
17
  from pathlib import Path
18
 
19
  from demucs import pretrained
20
- from transformers import ClapModel, AutoTokenizer
 
21
 
22
  from src.models.stem_separation.ATHTDemucs_v2 import AudioTextHTDemucs
23
  from utils import load_config, plot_spectrogram
@@ -27,7 +28,6 @@ from utils import load_config, plot_spectrogram
27
  # ============================================================================
28
 
29
  cfg = load_config("config.yaml")
30
- CHECKPOINT_PATH = cfg["training"]["resume_from"] # Change as needed
31
  SAMPLE_RATE = cfg["data"]["sample_rate"]
32
  SEGMENT_SECONDS = cfg["data"]["segment_seconds"]
33
  OVERLAP = cfg["data"]["overlap"]
@@ -41,6 +41,12 @@ else:
41
  DEVICE = "cpu"
42
  # DEVICE = "cpu"
43
 
 
 
 
 
 
 
44
 
45
  # ============================================================================
46
  # Model Loading
@@ -57,8 +63,8 @@ tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
57
  print("Building AudioTextHTDemucs...")
58
  model = AudioTextHTDemucs(htdemucs, clap, tokenizer)
59
 
60
- print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
61
- checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
62
  model.load_state_dict(checkpoint["model_state_dict"], strict=False)
63
  print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
64
 
 
17
  from pathlib import Path
18
 
19
  from demucs import pretrained
20
+ from transformers import ClapModel, AutoTokenizer, AutoModel
21
+ from huggingface_hub import hf_hub_download
22
 
23
  from src.models.stem_separation.ATHTDemucs_v2 import AudioTextHTDemucs
24
  from utils import load_config, plot_spectrogram
 
28
  # ============================================================================
29
 
30
  cfg = load_config("config.yaml")
 
31
  SAMPLE_RATE = cfg["data"]["sample_rate"]
32
  SEGMENT_SECONDS = cfg["data"]["segment_seconds"]
33
  OVERLAP = cfg["data"]["overlap"]
 
41
  DEVICE = "cpu"
42
  # DEVICE = "cpu"
43
 
44
+ # Load model from HuggingFace Hub
45
+ # TODO: Add our model to the AutoModel interface
46
+ ckpt = hf_hub_download(
47
+ repo_id="jacob1576/AudioTextHTDemucs",
48
+ filename="best_model.pt"
49
+ )
50
 
51
  # ============================================================================
52
  # Model Loading
 
63
  print("Building AudioTextHTDemucs...")
64
  model = AudioTextHTDemucs(htdemucs, clap, tokenizer)
65
 
66
+ print(f"Loading checkpoint from HuggingFace Hub...")
67
+ checkpoint = torch.load(ckpt, map_location="cpu")
68
  model.load_state_dict(checkpoint["model_state_dict"], strict=False)
69
  print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
70
 
config.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ train_dir: /home/jacob/datasets/musdb18/train # Path to train subfolder of MUSDB18 dataset
3
+ test_dir: /home/jacob/datasets/musdb18/test # Path to test subfolder of MUSDB18 dataset
4
+ segment_seconds: 6.0 # Length of audio segments for training [s]
5
+ pct_train: 0.2 # Decimal percentage of full data to use for training (otherwise 1 epoch takes ~15 hrs)
6
+ pct_test: 0.1 # Decimal percentage of full data to use for testing
7
+ overlap: 0.1 # Overlap between segments for chunked inference [s]
8
+ sample_rate: 44100 # Sample rate for audio files [Hz]
9
+ channels: 2 # Number of audio channels (1 = mono, 2 = stereo)
10
+ random_segments: False # Whether to use random segments during training
11
+ augment: True # Whether to use data augmentation (gain adjustment and channel swapping)
12
+
13
+ model:
14
+ name: Audio-Text-HTDemucs # Model name
15
+ model_dim: 384 # Model dimension
16
+ text_dim: 512 # Text embedding dimension (laion/clap-htsat-unfused is 512)
17
+ num_heads: 8 # Number of attention heads for text cross-attention layer
18
+ device: cpu # Device to use for training (cuda for GPU or cpu)
19
+ use_amp: False # Whether to use automatic mixed precision (AMP) during training - WORK IN PROGRESS
20
+
21
+ training:
22
+ batch_size: 8 # Batch size for training
23
+ num_workers: 0 # Number of DataLoader workers
24
+ num_epochs: 20 # Number of training epochs
25
+ optimizer:
26
+ name: AdamW
27
+ lr: 1e-4 # Learning rate
28
+ weight_decay: 1e-2 # Weight decay for optimizer
29
+ grad_clip: 5.0 # Gradient clipping value (set to null to disable)
30
+ loss_weights:
31
+ sdr: 0.9 # Weight for SDR loss
32
+ sisdr_weight: 0.1 # Weight for SI-SDR loss, total loss is (sdr_weight * sdr) + (sisdr_weight * si_sdr)
33
+ use_L1_comb_loss: False # Whether to use L1 combination loss
34
+ L1_comb_loss:
35
+ sdr_weight: 1.0 # Weight for SDR in L1 combination loss
36
+ l1_weight: 0.1 # Weight for L1 loss in L1 combination loss
37
+ #resume_from: null # Path to checkpoint to resume training from (set to null to train from scratch)
38
+ resume_from: checkpoints/2025_11_30_batch4/best_model.pt
39
+
40
+ wandb:
41
+ use_wandb: False # Whether to use Weights & Biases for experiment tracking
42
+ project: audio-text-htdemucs # Wandb project name
43
+ run_name: null
44
+ log_every: 50 # Log to wandb every N batches
45
+ validate_every: 1 # Validate every N epochs
46
+ save_every: 5 # Save model checkpoint every N epochs
47
+ checkpoint_dir: checkpoints/2025_12_06/ # Directory to save model checkpoints
48
+ output_dir: results/2025_12_06 # Directory to save inference results
requirements.txt CHANGED
@@ -1,3 +1,5 @@
 
 
1
  demucs==4.0.1
2
  gradio==5.17.1
3
  huggingface_hub
@@ -6,6 +8,9 @@ librosa
6
  loralib
7
  matplotlib==3.10.1
8
  numpy==2.1.3
 
 
 
9
  pathlib
10
  pydantic==2.10.6
11
  soundfile==0.13.1
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+
3
  demucs==4.0.1
4
  gradio==5.17.1
5
  huggingface_hub
 
8
  loralib
9
  matplotlib==3.10.1
10
  numpy==2.1.3
11
+ torch==2.6.0
12
+ torchvision==0.21.0
13
+ torchaudio==2.6.0
14
  pathlib
15
  pydantic==2.10.6
16
  soundfile==0.13.1