Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -39,6 +39,7 @@ def clean_text(text: str) -> str:
|
|
| 39 |
"""
|
| 40 |
Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
|
| 41 |
"""
|
|
|
|
| 42 |
return re.sub(r'\*', '', text)
|
| 43 |
|
| 44 |
# ---------------------------------------------------------------------
|
|
@@ -74,6 +75,7 @@ def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
|
|
| 74 |
|
| 75 |
model = MusicgenForConditionalGeneration.from_pretrained(model_key)
|
| 76 |
processor = AutoProcessor.from_pretrained(model_key)
|
|
|
|
| 77 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 78 |
model.to(device)
|
| 79 |
MUSICGEN_MODELS[model_key] = (model, processor)
|
|
@@ -203,9 +205,7 @@ def generate_music(prompt: str, audio_length: int):
|
|
| 203 |
musicgen_model, musicgen_processor = get_musicgen_model(model_key)
|
| 204 |
|
| 205 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 206 |
-
|
| 207 |
-
inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
|
| 208 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 209 |
|
| 210 |
with torch.inference_mode():
|
| 211 |
outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
|
|
|
|
| 39 |
"""
|
| 40 |
Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
|
| 41 |
"""
|
| 42 |
+
# Remove all asterisks. You can add more cleaning steps here as needed.
|
| 43 |
return re.sub(r'\*', '', text)
|
| 44 |
|
| 45 |
# ---------------------------------------------------------------------
|
|
|
|
| 75 |
|
| 76 |
model = MusicgenForConditionalGeneration.from_pretrained(model_key)
|
| 77 |
processor = AutoProcessor.from_pretrained(model_key)
|
| 78 |
+
|
| 79 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 80 |
model.to(device)
|
| 81 |
MUSICGEN_MODELS[model_key] = (model, processor)
|
|
|
|
| 205 |
musicgen_model, musicgen_processor = get_musicgen_model(model_key)
|
| 206 |
|
| 207 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 208 |
+
inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
|
|
|
|
|
|
|
| 209 |
|
| 210 |
with torch.inference_mode():
|
| 211 |
outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
|