| """ |
| Test ViralTrack Predictor - Spotify Popularity Prediction |
| Usage: python test_model.py [model_path]  (model_path defaults to "model") |
| """ |
|
|
| import sys |
| import logging |
| import torch |
| from pathlib import Path |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
| |
# Raise the threshold of the third-party library loggers to ERROR so their
# warnings/info messages do not clutter this script's console output.
for _noisy_logger in ("transformers", "torch"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
|
|
|
|
def load_model(model_path):
    """Load the trained tokenizer and sequence-classification model.

    Args:
        model_path: Path passed to ``from_pretrained`` for both the tokenizer
            and the model.

    Returns:
        ``(model, tokenizer)``; the model is put in eval mode so that
        dropout is disabled during inference.
    """
    print(f"⏳ Loading model from: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()  # inference only
    print("✓ Model loaded\n")
    return model, tokenizer
|
|
|
|
def predict_popularity(model, tokenizer, track_name, artists="", audio_features=None):
    """
    Predict popularity for a track.

    The track is serialised into one text string of the form
    ``"<track_name> <artists> | key:value key:value ..."`` (feature values
    formatted with 3 decimals, in the dict's iteration order), tokenized
    (truncated to 128 tokens) and fed to the model.

    Args:
        model: Trained model. Assumed to produce a single regression logit on
            a 0-1 scale, since the output is read with ``.item()`` and
            multiplied by 100 (TODO confirm against the training setup).
        tokenizer: Tokenizer matching the model.
        track_name: Song title.
        artists: Artist name(s); left out of the text when empty.
        audio_features: Optional dict of audio features (danceability,
            energy, etc.). Keys 'track_name' and 'artists' are ignored.

    Returns:
        ``(popularity_score, recommendations)``: the score is a float clamped
        to [0.0, 100.0]; recommendations is the list produced by
        ``generate_recommendations``.
    """
    features = audio_features or {}

    text_parts = [track_name]
    if artists:
        text_parts.append(artists)
    combined_text = ' '.join(text_parts)

    # Numeric features are appended as "name:value" tokens after a " | " separator.
    numerical = [
        f"{col}:{float(val):.3f}"
        for col, val in features.items()
        if col not in ('track_name', 'artists')
    ]
    input_text = f"{combined_text} | {' '.join(numerical)}" if numerical else combined_text

    inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=128)

    # Put the input tensors on the model's device so a model the caller has
    # moved to GPU does not fail with a device-mismatch error.
    device = getattr(model, 'device', None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # .item() requires exactly one logit (single-output regression head).
    raw_score = outputs.logits.item() * 100

    # Float bounds keep the return type a float even when the score is clamped.
    score = max(0.0, min(100.0, raw_score))

    recommendations = generate_recommendations(score, features)
    return score, recommendations
|
|
|
|
def generate_recommendations(prediction: float, features: dict) -> list:
    """Generate actionable recommendations based on prediction and features.

    Args:
        prediction: Popularity score on a 0-100 scale.
        features: Audio-feature dict; may be empty or None. Only features
            that are actually present are checked, so missing data never
            produces advice.

    Returns:
        A list of strings: first a headline for the score band
        (<40 LOW, <60 MODERATE, <80 GOOD, otherwise HIGH), then one line per
        feature rule that fires, in a fixed order.
    """
    if prediction < 40:
        headline = "⚠️ Predicted popularity is LOW - consider these changes:"
    elif prediction < 60:
        headline = "📊 Predicted popularity is MODERATE - optimization opportunities:"
    elif prediction < 80:
        headline = "✅ Predicted popularity is GOOD - track has solid potential!"
    else:
        headline = "🔥 Predicted popularity is HIGH - track has VIRAL potential!"
    recommendations = [headline]

    # (feature key, condition that triggers the advice, advice text)
    rules = (
        ('duration_ms', lambda v: v > 200000,
         " 📏 Song is long (>3:20) - consider shorter version for TikTok/Reels"),
        ('energy', lambda v: v < 0.4,
         " ⚡ Low energy - consider adding more dynamic elements"),
        ('danceability', lambda v: v < 0.5,
         " 💃 Low danceability - may not perform well on social platforms"),
        ('valence', lambda v: v > 0.8,
         " 😊 Very positive mood - great for playlists/morning vibes"),
        ('acousticness', lambda v: v > 0.7,
         " 🎸 Highly acoustic - consider production polish for mainstream appeal"),
        ('speechiness', lambda v: v > 0.3,
         " 🎤 High speechiness - may work well for podcast/hip-hop audiences"),
        ('instrumentalness', lambda v: v > 0.5,
         " 🎹 Instrumental track - consider adding vocals for broader appeal"),
        ('liveness', lambda v: v > 0.6,
         " 🎙️ Live recording - studio version may have wider appeal"),
    )

    features = features or {}
    for key, triggers, advice in rules:
        value = features.get(key)
        # Skip absent features: treating them as 0 would wrongly report
        # "low energy" / "low danceability" for tracks with no feature data.
        if value is not None and triggers(value):
            recommendations.append(advice)

    return recommendations
|
|
|
|
# Sample tracks used by test_model() when the caller does not supply any.
_DEFAULT_TEST_TRACKS = (
    {
        'track_name': "Bohemian Rhapsody",
        'artists': "Queen",
        'audio_features': {
            'danceability': 0.416,
            'energy': 0.489,
            'valence': 0.279,
            'tempo': 144.0,
            'duration_ms': 354947,
            'acousticness': 0.172,
            'instrumentalness': 0.0,
            'liveness': 0.207,
            'speechiness': 0.0467,
        },
    },
    {
        'track_name': "Shape of You",
        'artists': "Ed Sheeran",
        'audio_features': {
            'danceability': 0.825,
            'energy': 0.652,
            'valence': 0.931,
            'tempo': 96.0,
            'duration_ms': 233713,
            'acousticness': 0.581,
            'instrumentalness': 0.0,
            'liveness': 0.0931,
            'speechiness': 0.0802,
        },
    },
    {
        'track_name': "Blinding Lights",
        'artists': "The Weeknd",
        'audio_features': {
            'danceability': 0.514,
            'energy': 0.730,
            'valence': 0.334,
            'tempo': 171.0,
            'duration_ms': 200040,
            'acousticness': 0.00146,
            'instrumentalness': 0.000906,
            'liveness': 0.0897,
            'speechiness': 0.0598,
        },
    },
    {
        'track_name': "Bad Guy",
        'artists': "Billie Eilish",
        'audio_features': {
            'danceability': 0.703,
            'energy': 0.432,
            'valence': 0.560,
            'tempo': 135.0,
            'duration_ms': 194088,
            'acousticness': 0.133,
            'instrumentalness': 0.000234,
            'liveness': 0.0962,
            'speechiness': 0.378,
        },
    },
    {
        'track_name': "Old Town Road",
        'artists': "Lil Nas X",
        'audio_features': {
            'danceability': 0.547,
            'energy': 0.621,
            'valence': 0.645,
            'tempo': 136.0,
            'duration_ms': 157066,
            'acousticness': 0.0395,
            'instrumentalness': 0.0,
            'liveness': 0.117,
            'speechiness': 0.0924,
        },
    },
)


def _popularity_bar(score, width=20):
    """Render a 0-100 score as a text bar of `width` filled/empty cells."""
    # Multiply before dividing so whole-number scores map to exact cell counts;
    # clamp so out-of-range scores cannot over/underflow the bar.
    filled = max(0, min(width, int(score * width / 100)))
    return "█" * filled + "░" * (width - filled)


def test_model(model_path, test_tracks=None):
    """Run the model on sample tracks and print predictions plus a summary.

    Args:
        model_path: Path handed to ``load_model``.
        test_tracks: Optional iterable of dicts with a 'track_name' key and
            optional 'artists' and 'audio_features' keys. Defaults to
            ``_DEFAULT_TEST_TRACKS``.

    Returns:
        A list of dicts with keys 'track_name', 'predicted_popularity' and
        'recommendations', one per input track, in input order.
    """
    model, tokenizer = load_model(model_path)

    if test_tracks is None:
        test_tracks = _DEFAULT_TEST_TRACKS

    print("\n" + "=" * 70)
    print("🎵 ViralTrack Predictor - Popularity Prediction & Recommendations")
    print("=" * 70)

    results = []
    for track in test_tracks:
        track_name = track['track_name']
        artists = track.get('artists', '')
        audio_features = track.get('audio_features', {})

        popularity, recommendations = predict_popularity(model, tokenizer, track_name, artists, audio_features)

        results.append({
            'track_name': track_name,
            'predicted_popularity': popularity,
            'recommendations': recommendations,
        })

        print(f"\n🎵 Track: '{track_name}'")
        print(f" Predicted Popularity: {popularity:.1f}/100")
        print("\n Recommendations:")
        for rec in recommendations:
            print(f" {rec}")

    print("\n" + "=" * 70)
    print("📊 Summary")
    print("=" * 70)
    for r in results:
        bar = _popularity_bar(r['predicted_popularity'])
        print(f" {r['track_name'][:25]:<25} [{bar}] {r['predicted_popularity']:.1f}")

    print("\n" + "=" * 70)
    print("💡 Tip: Run with custom track:")
    print(" python test_model.py <model_path>")
    print("=" * 70)

    return results
|
|
|
|
def interactive_mode(model, tokenizer):
    """Prompt for tracks on stdin and print a prediction for each one.

    The loop ends when the user types 'quit', 'exit' or 'q', when stdin is
    closed (EOF, e.g. Ctrl-D or end of piped input), or on Ctrl-C while a
    prompt is waiting.

    Args:
        model: Model passed through to ``predict_popularity``.
        tokenizer: Tokenizer passed through to ``predict_popularity``.
    """
    print("\n" + "=" * 70)
    print("🎤 Interactive Mode - Enter track details (or 'quit' to exit)")
    print("=" * 70)

    feature_names = ('danceability', 'energy', 'valence', 'tempo', 'duration_ms',
                     'acousticness', 'instrumentalness', 'liveness', 'speechiness')

    while True:
        try:
            track_name = input("\n🎵 Track name: ").strip()
            if track_name.lower() in ('quit', 'exit', 'q'):
                break
            if not track_name:
                # Predicting on an empty string is meaningless; ask again.
                print(" Please enter a track name (or 'quit' to exit).")
                continue

            artists = input(" Artists: ").strip()

            use_features = input(" Add audio features? (y/n): ").strip().lower()
            audio_features = {}
            if use_features in ('y', 'yes'):
                print(" Enter features (or press Enter to skip):")
                for feat in feature_names:
                    val = input(f" {feat}: ").strip()
                    if not val:
                        continue
                    try:
                        audio_features[feat] = float(val)
                    except ValueError:
                        # Tell the user instead of silently dropping the value.
                        print(f" Ignoring {feat}: {val!r} is not a number")
        except (EOFError, KeyboardInterrupt):
            # Closed or interrupted input: leave the loop cleanly, no traceback.
            print()
            break

        popularity, recommendations = predict_popularity(model, tokenizer, track_name, artists, audio_features)

        print(f"\n 📊 Predicted Popularity: {popularity:.1f}/100")
        print("\n 💡 Recommendations:")
        for rec in recommendations:
            print(f" {rec}")
|
|
|
|
if __name__ == '__main__':
    # Optional first CLI argument; falls back to a "model" directory
    # relative to the current working directory.
    model_path = sys.argv[1] if len(sys.argv) > 1 else 'model'

    if not Path(model_path).exists():
        print(f"❌ Error: Model not found at {model_path}")
        print(" Run training first: ./run.sh config")
        sys.exit(1)

    # test_model() loads the model and tokenizer itself; loading them here
    # too would read the weights from disk twice for no benefit.
    test_model(model_path)
|
|