| Para un uso sencillo del modelo, utilice el siguiente código: | |
| ```py | |
| #! pip install transformers | |
| #! pip install torch | |
| #! pip install datasets | |
| ``` | |
| ```py | |
| from transformers import DistilBertForSequenceClassification, DistilBertTokenizer | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from datasets import load_dataset | |
| import numpy as np | |
| import torch | |
| dataset = load_dataset("manoh2f2/songs_resampled") | |
| # Cargar el dataset en un DataFrame | |
| split_name = 'train' | |
| df_resampled = dataset[split_name].to_pandas() | |
| tokenizer = AutoTokenizer.from_pretrained("manoh2f2/recommend_songs") | |
| model = AutoModelForSequenceClassification.from_pretrained("manoh2f2/recommend_songs") | |
| # Define a prompt | |
| prompt = "I am happy" | |
| # Tokenize the prompt | |
| encoded_prompt = tokenizer(prompt, return_tensors='pt', max_length=256) | |
| # Make a prediction using the trained model | |
| with torch.no_grad(): | |
| model_output = model(**encoded_prompt) | |
| # Get the predicted emotion index | |
| predicted_emotion_index = torch.argmax(model_output.logits).item() | |
| # Map the index back to the emotion label using the DataFrame | |
| predicted_emotion_label = df_resampled['emotions'].unique()[predicted_emotion_index] | |
| # Get a song associated with the predicted emotion from the DaraFrame | |
| result = df_resampled[df_resampled['emotions'] == predicted_emotion_label] | |
| # Get the number of rows in the DataFrame | |
| num_rows = result.shape[0] | |
| #Generate a random index to select a random song from the DataFrame | |
| random_index = np.random.randint(0, num_rows) | |
| #Get the recommended song and artist | |
| recommended_song = result['song'].iloc[random_index] | |
| recommended_artist = result['artist'].iloc[random_index] | |
| #Print the results | |
| print(f"Prompt: {prompt}") | |
| print(f"Predicted Emotion: {predicted_emotion_label}") | |
| print(f"Recommended Song: {recommended_song} - {recommended_artist}") | |
| ``` |