|
|
""" |
|
|
Example usage script for the T5 Spotify Features model (synyyy/t5-spotify-features-v2).
|
|
""" |
|
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
|
import json |
|
|
|
|
|
def load_model(model_name="synyyy/t5-spotify-features-v2"):
    """Load the T5 model and its tokenizer.

    Args:
        model_name: Identifier (Hub repository id or local directory) passed to
            ``from_pretrained`` for both the model and the tokenizer. Defaults
            to the published Spotify-features checkpoint used by this script.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Use one name for both objects so the tokenizer always matches the model.
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer
|
|
|
|
|
def _ensure_braces(text):
    """Return *text* stripped of surrounding whitespace and enclosed in ``{...}``.

    Each brace is added independently, so output missing only the opening or
    only the closing brace is repaired too, and already-braced text is left as is.
    """
    text = text.strip()
    if not text.startswith("{"):
        text = "{" + text
    if not text.endswith("}"):
        text += "}"
    return text


def generate_spotify_features(prompt, model, tokenizer):
    """Generate Spotify features from a text prompt.

    Args:
        prompt: Free-text description of the desired music; it is prefixed
            with ``"prompt: "`` before encoding.
        model: Seq2seq model exposing ``generate`` (e.g. from ``load_model``).
        tokenizer: Matching tokenizer used to encode the prompt and decode the
            generated ids.

    Returns:
        The decoded output parsed as a JSON object (a ``dict``), or ``None``
        if it is not valid JSON. In that case the parse error and the raw
        (brace-wrapped) output are printed.
    """
    input_text = f"prompt: {prompt}"
    input_ids = tokenizer(
        input_text, return_tensors="pt", max_length=256, truncation=True
    ).input_ids

    # Deterministic beam search (no sampling), so a prompt always maps to the same output.
    outputs = model.generate(
        input_ids,
        max_length=256,
        num_beams=4,
        early_stopping=True,
        do_sample=False,
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text may lack the enclosing braces. Presumably T5's
    # SentencePiece vocabulary cannot emit "{"/"}" (TODO confirm), so restore
    # them before parsing.
    result = _ensure_braces(result)

    try:
        return json.loads(result)
    except json.JSONDecodeError as e:
        print(f"JSON parsing failed: {e}")
        print(f"Raw output: {result}")
        return None
|
|
|
|
|
if __name__ == "__main__":
    # Demo: load the checkpoint once, then run a handful of sample prompts.
    print("Loading model...")
    model, tokenizer = load_model()

    # Sample prompts spanning different moods and activities.
    sample_prompts = [
        "energetic dance music for parties",
        "calm acoustic music for studying",
        "upbeat pop songs for working out",
        "relaxing instrumental background music",
        "happy music for road trips",
    ]

    banner = "=" * 50
    separator = "-" * 30

    print("\nGenerating features for test prompts:")
    print(banner)

    for text in sample_prompts:
        print(f"\nPrompt: {text}")
        features = generate_spotify_features(text, model, tokenizer)
        # Empty or missing results (None / {}) are reported as failures.
        summary = (
            f"Features: {json.dumps(features, indent=2)}"
            if features
            else "Failed to generate valid features"
        )
        print(summary)
        print(separator)
|
|
|