afsagag commited on
Commit
9d74db1
·
verified ·
1 Parent(s): 3e8322b

Add example usage script

Browse files
Files changed (1) hide show
  1. example_usage.py +58 -0
example_usage.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example script for using the T5 Spotify Features model
3
+ """
4
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
5
+ import json
6
+
7
+ def predict_spotify_features(prompt_text, model_name="afsagag/t5-spotify-features"):
8
+ """
9
+ Generate Spotify audio features from a text prompt
10
+
11
+ Args:
12
+ prompt_text (str): Natural language description of music preferences
13
+ model_name (str): Hugging Face model name
14
+
15
+ Returns:
16
+ dict: Spotify audio features or None if JSON parsing fails
17
+ """
18
+ # Load model and tokenizer
19
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
20
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
21
+
22
+ # Format input
23
+ input_text = f"prompt: {prompt_text}"
24
+
25
+ # Tokenize and generate
26
+ input_ids = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True).input_ids
27
+ outputs = model.generate(
28
+ input_ids,
29
+ max_length=256,
30
+ num_beams=4,
31
+ early_stopping=True,
32
+ do_sample=False
33
+ )
34
+
35
+ # Decode and clean result
36
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
37
+ cleaned_result = result.replace("ll", "null").replace("nu", "null")
38
+
39
+ try:
40
+ return json.loads(cleaned_result)
41
+ except json.JSONDecodeError:
42
+ print(f"Failed to parse JSON: {cleaned_result}")
43
+ return None
44
+
45
+ if __name__ == "__main__":
46
+ # Example prompts
47
+ test_prompts = [
48
+ "I want energetic dance music",
49
+ "Play some calm acoustic songs",
50
+ "Upbeat pop music for working out",
51
+ "Sad slow songs for rainy days"
52
+ ]
53
+
54
+ for prompt in test_prompts:
55
+ print(f"\nPrompt: {prompt}")
56
+ features = predict_spotify_features(prompt)
57
+ if features:
58
+ print(f"Features: {json.dumps(features, indent=2)}")