# t5-spotify-features-v2 / example_usage.py
# Source: Hugging Face model repo "synyyy/t5-spotify-features-v2" (commit 942de71)
"""
Example usage script for T5 Spotify Features model
"""
from transformers import T5ForConditionalGeneration, T5Tokenizer
import json
def load_model():
    """Download and return the fine-tuned T5 model and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` pair pulled from the Hugging Face Hub.
    """
    repo_id = "synyyy/t5-spotify-features-v2"
    tokenizer = T5Tokenizer.from_pretrained(repo_id)
    model = T5ForConditionalGeneration.from_pretrained(repo_id)
    return model, tokenizer
def generate_spotify_features(prompt, model, tokenizer):
    """Generate Spotify audio features (as a dict) from a free-text prompt.

    Args:
        prompt: Natural-language description of the desired music.
        model: A loaded ``T5ForConditionalGeneration`` instance.
        tokenizer: The matching ``T5Tokenizer``.

    Returns:
        The parsed feature dict, or ``None`` when the model output cannot
        be decoded as JSON.
    """
    input_text = f"prompt: {prompt}"
    input_ids = tokenizer(
        input_text, return_tensors="pt", max_length=256, truncation=True
    ).input_ids
    outputs = model.generate(
        input_ids,
        max_length=256,
        num_beams=4,
        early_stopping=True,
        do_sample=False,  # deterministic beam search, no sampling
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The model sometimes emits the JSON body without its enclosing braces;
    # re-wrap before parsing. (BUGFIX: the pasted source had the '{' / '}'
    # literals stripped, which made this whole file a SyntaxError.)
    stripped = result.strip()
    if not stripped.startswith("{") and not stripped.endswith("}"):
        result = "{" + result + "}"
    try:
        return json.loads(result)
    except json.JSONDecodeError as e:
        print(f"JSON parsing failed: {e}")
        print(f"Raw output: {result}")
        return None
if __name__ == "__main__":
    # Load model once up front, then run every demo prompt through it.
    print("Loading model...")
    model, tokenizer = load_model()

    demo_prompts = (
        "energetic dance music for parties",
        "calm acoustic music for studying",
        "upbeat pop songs for working out",
        "relaxing instrumental background music",
        "happy music for road trips",
    )

    print("\nGenerating features for test prompts:")
    print("=" * 50)

    for prompt in demo_prompts:
        print(f"\nPrompt: {prompt}")
        features = generate_spotify_features(prompt, model, tokenizer)
        if features:
            print(f"Features: {json.dumps(features, indent=2)}")
        else:
            print("Failed to generate valid features")
        print("-" * 30)