kepsmiling121 committed on
Commit
d4a5fb6
·
verified ·
1 Parent(s): df2fec5

Create models/musicgen_model.py

Browse files
Files changed (1) hide show
  1. models/musicgen_model.py +159 -0
models/musicgen_model.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MusicGen model wrapper with advanced features
3
+ """
4
+ import torch
5
+ import numpy as np
6
+ from typing import Optional, Dict, List
7
+ from transformers import AutoProcessor, MusicgenForConditionalGeneration
8
+ import scipy
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
class MusicGenModel:
    """Wrapper around a Hugging Face MusicGen checkpoint.

    Loads the processor and model once at construction time and exposes
    text-, audio- and text+audio-conditioned generation. Every generation
    method returns a 1-D float32 NumPy array holding the first channel of
    the first generated batch item.
    """

    # Audio tokens generated per second of output; used to turn a duration
    # in seconds into ``max_new_tokens`` (the original "rough conversion").
    TOKENS_PER_SECOND = 50

    def __init__(self, model_id: str = "facebook/musicgen-small"):
        """Store the checkpoint id and load processor and model eagerly.

        Args:
            model_id: Hub id or local path passed to ``from_pretrained``.

        Raises:
            Exception: Whatever ``_load_model`` raises if loading fails.
        """
        self.model_id = model_id
        self.processor = None
        self.model = None
        self.device = None
        self._load_model()

    def _load_model(self):
        """Load processor and model, move the model to the best device.

        Uses CUDA with float16 weights when available, otherwise CPU with
        float32 weights. Failures are logged and re-raised unchanged.
        """
        try:
            logger.info(f"Loading MusicGen model: {self.model_id}")

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            weights_dtype = torch.float16 if self.device == "cuda" else torch.float32

            self.processor = AutoProcessor.from_pretrained(self.model_id)
            self.model = MusicgenForConditionalGeneration.from_pretrained(
                self.model_id,
                torch_dtype=weights_dtype,
            )

            self.model.to(self.device)
            self.model.eval()

            logger.info(f"Model loaded successfully on {self.device}")

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    @classmethod
    def _duration_to_tokens(cls, duration) -> int:
        """Convert a duration in seconds to a positive ``max_new_tokens``."""
        token_count = int(duration * cls.TOKENS_PER_SECOND)
        if token_count <= 0:
            raise ValueError(
                f"duration must be long enough to produce at least one token, got {duration!r}"
            )
        return token_count

    def _prepare_audio(self, audio_array: np.ndarray, sampling_rate: int):
        """Return ``(audio, rate)`` resampled to the feature extractor's rate.

        The processor's feature extractor rejects audio whose declared rate
        differs from its own, so the input is resampled along its last axis
        (assumed to be time -- TODO confirm for multi-channel callers)
        instead of being passed through with a mismatching rate.
        """
        source_rate = int(sampling_rate)
        if source_rate <= 0:
            raise ValueError(f"sampling_rate must be positive, got {sampling_rate!r}")

        samples = np.asarray(audio_array, dtype=np.float32)
        target_rate = int(self.processor.feature_extractor.sampling_rate)
        if source_rate != target_rate:
            # Local imports: only needed on the resampling path.
            from math import gcd
            from scipy.signal import resample_poly

            common = gcd(source_rate, target_rate)
            samples = resample_poly(
                samples, target_rate // common, source_rate // common, axis=-1
            ).astype(np.float32)
        return samples, target_rate

    def _run_generation(self, inputs, max_new_tokens: int, **generate_kwargs) -> np.ndarray:
        """Move processor outputs to the model and run ``generate``.

        Floating-point tensors are cast to the model's dtype so float32
        audio features can be fed to float16 weights on CUDA; integer
        tensors (token ids, masks) only change device.
        """
        model_inputs = {}
        for key, value in inputs.items():
            if not torch.is_tensor(value):
                model_inputs[key] = value
            elif torch.is_floating_point(value):
                model_inputs[key] = value.to(self.device, dtype=self.model.dtype)
            else:
                model_inputs[key] = value.to(self.device)

        with torch.no_grad():
            audio_values = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )

        # First batch item, first channel. Cast to float32 so the result is
        # not a float16 array when the model runs in half precision.
        return audio_values[0, 0].float().cpu().numpy()

    def generate_from_text(
        self,
        prompt: str,
        duration: int = 10,
        guidance_scale: float = 3.0,
        temperature: float = 1.0,
        top_k: int = 50,
        do_sample: bool = True
    ) -> np.ndarray:
        """Generate music from a text prompt.

        Args:
            prompt: Text description of the desired music.
            duration: Target length in seconds (50 tokens per second).
            guidance_scale: Guidance scale forwarded to ``generate``.
            temperature: Sampling temperature forwarded to ``generate``.
            top_k: Top-k cutoff forwarded to ``generate``.
            do_sample: Whether ``generate`` samples instead of decoding greedily.

        Returns:
            1-D float32 array with the generated waveform.

        Raises:
            ValueError: If ``duration`` does not yield at least one token.
            Exception: Any processor/model error, logged and re-raised.
        """
        max_new_tokens = self._duration_to_tokens(duration)
        try:
            inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
            )
            return self._run_generation(
                inputs,
                max_new_tokens,
                do_sample=do_sample,
                guidance_scale=guidance_scale,
                temperature=temperature,
                top_k=top_k,
            )

        except Exception as e:
            logger.error(f"Text generation failed: {str(e)}")
            raise

    def generate_from_audio(
        self,
        audio_array: np.ndarray,
        duration: int = 10,
        guidance_scale: float = 3.0,
        sampling_rate: int = 16000
    ) -> np.ndarray:
        """Generate music conditioned on input audio.

        Args:
            audio_array: Conditioning waveform sampled at ``sampling_rate``.
            duration: Target length in seconds (50 tokens per second).
            guidance_scale: Guidance scale forwarded to ``generate``.
            sampling_rate: Rate of ``audio_array`` in Hz. Defaults to the
                16 kHz this method has always assumed; the audio is
                resampled to the processor's rate before encoding.

        Returns:
            1-D float32 array with the generated waveform.

        Raises:
            ValueError: If ``duration`` does not yield at least one token
                or ``sampling_rate`` is not positive.
            Exception: Any processor/model error, logged and re-raised.
        """
        max_new_tokens = self._duration_to_tokens(duration)
        try:
            samples, model_rate = self._prepare_audio(audio_array, sampling_rate)
            inputs = self.processor(
                audio=samples,
                sampling_rate=model_rate,
                padding=True,
                return_tensors="pt",
            )
            return self._run_generation(
                inputs,
                max_new_tokens,
                do_sample=True,
                guidance_scale=guidance_scale,
            )

        except Exception as e:
            logger.error(f"Audio conditioning failed: {str(e)}")
            raise

    def generate_from_text_and_audio(
        self,
        prompt: str,
        audio_array: np.ndarray,
        duration: int = 10,
        guidance_scale: float = 3.0,
        sampling_rate: int = 16000
    ) -> np.ndarray:
        """Generate music from both a text prompt and conditioning audio.

        Args:
            prompt: Text description of the desired music.
            audio_array: Conditioning waveform sampled at ``sampling_rate``.
            duration: Target length in seconds (50 tokens per second).
            guidance_scale: Guidance scale forwarded to ``generate``.
            sampling_rate: Rate of ``audio_array`` in Hz. Defaults to the
                16 kHz this method has always assumed; the audio is
                resampled to the processor's rate before encoding.

        Returns:
            1-D float32 array with the generated waveform.

        Raises:
            ValueError: If ``duration`` does not yield at least one token
                or ``sampling_rate`` is not positive.
            Exception: Any processor/model error, logged and re-raised.
        """
        max_new_tokens = self._duration_to_tokens(duration)
        try:
            samples, model_rate = self._prepare_audio(audio_array, sampling_rate)
            inputs = self.processor(
                text=[prompt],
                audio=samples,
                sampling_rate=model_rate,
                padding=True,
                return_tensors="pt",
            )
            return self._run_generation(
                inputs,
                max_new_tokens,
                do_sample=True,
                guidance_scale=guidance_scale,
            )

        except Exception as e:
            logger.error(f"Combined generation failed: {str(e)}")
            raise

    def batch_generate(
        self,
        prompts: List[str],
        duration: int = 10,
        guidance_scale: float = 3.0
    ) -> List[np.ndarray]:
        """Generate one sample per prompt, sequentially.

        Args:
            prompts: Text prompts; an empty list yields an empty result.
            duration: Target length in seconds for every sample.
            guidance_scale: Guidance scale applied to every sample.

        Returns:
            Generated waveforms in the same order as ``prompts``.
        """
        return [
            self.generate_from_text(
                prompt=prompt,
                duration=duration,
                guidance_scale=guidance_scale,
            )
            for prompt in prompts
        ]