Spaces:
Sleeping
Sleeping
Upload taylor's_version_data_science_–_assignment_3_(group_project).py
Browse files
taylor's_version_data_science_–_assignment_3_(group_project).py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
"""Taylor's Version Data Science – Assignment 3 (Group Project).ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1UINZIWnD3atFdcs5TWUxZ_Af7c5gLhrZ
"""

# Colab shell magics: install/upgrade the Hugging Face libraries used below.
!pip install -q datasets huggingface_hub
!pip install -U datasets

from google.colab import userdata
from huggingface_hub import login

hf_token = userdata.get("HF_TOKEN")  # pull the token stored in Colab secrets
login(hf_token)
|
| 18 |
+
|
| 19 |
+
import pandas as pd

# File names of each split inside the Hugging Face dataset repo.
splits = {'train': 'train_clean.csv', 'validation': 'val_clean.csv', 'test': 'test_clean.csv'}
df = pd.read_csv("hf://datasets/hita/social-behavior-emotions/" + splits["train"])

df.head(10)

# Fix the misspelled column name (if needed)
df = df.rename(columns={"lable": "Label"})

# General overview of the dataset
print("=== Dataset Info ===")
print(df.info())

print("\n=== Missing Values ===")
print(df.isnull().sum())

print("\n=== Duplicate Rows ===")
print(df.duplicated().sum())

print("\n=== Label Distribution ===")
print(df['Label'].value_counts())
|
| 41 |
+
|
| 42 |
+
# Label distribution plot
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(8, 5))
sns.countplot(data=df, x='Label', order=df['Label'].value_counts().index)
plt.title("Distribution of Emotion Labels")
plt.xlabel("Emotion")
plt.ylabel("Count")
plt.xticks(rotation=45)
plt.grid(axis="y", linestyle="--", alpha=0.6)
plt.tight_layout()
plt.show()
|
| 55 |
+
|
| 56 |
+
# Compute text length in words for each review
df['text_length'] = df['CleanReview'].apply(lambda x: len(str(x).split()))

plt.figure(figsize=(8, 5))
sns.histplot(df['text_length'], bins=30, kde=True)
plt.title("Distribution of Text Lengths (Word Count)")
plt.xlabel("Number of Words")
plt.ylabel("Frequency")
plt.tight_layout()
plt.show()
|
| 66 |
+
|
| 67 |
+
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns

# Split all texts into words (without NLTK)
all_words = " ".join(df['CleanReview'].dropna()).lower().split()
word_freq = Counter(all_words)

# Show the 20 most common words
top_words = word_freq.most_common(20)
words, counts = zip(*top_words)

# Plot
plt.figure(figsize=(10, 5))
sns.barplot(x=list(counts), y=list(words))
plt.title("Top 20 Most Common Words")
plt.xlabel("Frequency")
plt.ylabel("Word")
plt.tight_layout()
plt.show()
|
| 87 |
+
|
| 88 |
+
from collections import Counter

# Count the number of words in each row
word_counts = df['CleanReview'].dropna().apply(lambda x: len(str(x).split()))

# Count how many rows exist for each word count
length_distribution = Counter(word_counts)

# Convert to a tidy DataFrame
import pandas as pd
length_df = pd.DataFrame(length_distribution.items(), columns=["Word Count", "Sentence Count"])
length_df = length_df.sort_values(by="Word Count").reset_index(drop=True)

# Display
print(length_df)
|
| 103 |
+
|
| 104 |
+
# Drop one-word (and empty) reviews — they carry too little signal.
df = df[df['text_length'] > 1].reset_index(drop=True)
from collections import Counter

# Recompute the word count of each row (same check as the previous cell,
# re-run after the filter to confirm the short rows are gone)
word_counts = df['CleanReview'].dropna().apply(lambda x: len(str(x).split()))

# Count how many rows exist for each word count
length_distribution = Counter(word_counts)

# Convert to a tidy DataFrame
import pandas as pd
length_df = pd.DataFrame(length_distribution.items(), columns=["Word Count", "Sentence Count"])
length_df = length_df.sort_values(by="Word Count").reset_index(drop=True)

# Display
print(length_df)
|
| 120 |
+
|
| 121 |
+
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from wordcloud import WordCloud
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

# 1. Remove duplicate rows
df = df.drop_duplicates()
print(f"✅ DataFrame shape after dropping duplicates: {df.shape}")

# 2. Text-length analysis per emotion
sns.boxplot(x='Label', y='text_length', data=df)
plt.title("Text Length by Emotion")
plt.xlabel("Emotion")
plt.ylabel("Number of Words")
plt.tight_layout()
plt.show()

# 3. WordCloud for each emotion
emotions = df['Label'].unique()

for emotion in emotions:
    # All reviews of this emotion, lower-cased and stripped of English stopwords.
    text = " ".join(df[df['Label'] == emotion]['CleanReview'].dropna()).lower()
    text = " ".join([word for word in text.split() if word not in ENGLISH_STOP_WORDS])

    wordcloud = WordCloud(width=800, height=400, background_color='white').generate(text)

    plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis("off")
    plt.title(f"WordCloud for '{emotion}'")
    plt.show()
|
| 153 |
+
|
| 154 |
+
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import seaborn as sns

# 1. Custom stopword list: the standard English stopwords plus frequent but
#    uninformative words observed in this corpus.
custom_stopwords = ENGLISH_STOP_WORDS.union({'feel', 'know', 'make', 'think', 'time', 'really', 'want'})

# 2. WordCloud per emotion (custom stopwords removed)
print("✅ WordClouds per emotion (without stopwords):")
emotions = df['Label'].unique()

for emotion in emotions:
    text = " ".join(df[df['Label'] == emotion]['CleanReview'].dropna()).lower()
    words = [word for word in text.split() if word not in custom_stopwords]
    filtered_text = " ".join(words)

    wordcloud = WordCloud(width=800, height=400, background_color='white').generate(filtered_text)

    plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis("off")
    plt.title(f"WordCloud for '{emotion}' (No Stopwords)")
    plt.show()

# 3. Balance the classes by undersampling every emotion down to the size of
#    the rarest one.
min_count = df['Label'].value_counts().min()

# GroupBy.sample replaces the deprecated groupby().apply(lambda x: x.sample(...))
# pattern, which operates on the grouping column and raises FutureWarnings in
# pandas >= 2.2, while producing the same per-class random undersample.
df_balanced = df.groupby('Label').sample(n=min_count, random_state=42).reset_index(drop=True)

print("\n✅ New label distribution after balancing:")
print(df_balanced['Label'].value_counts())
|
| 186 |
+
|
| 187 |
+
"""# **2. Embeddings**"""
|
| 188 |
+
|
| 189 |
+
!pip install -q datasets sentence-transformers umap-learn scikit-learn matplotlib pandas
|
| 190 |
+
!pip install -q scikit-learn umap-learn
|
| 191 |
+
|
| 192 |
+
# שלב 2: יצירת embeddings
|
| 193 |
+
from sentence_transformers import SentenceTransformer
|
| 194 |
+
|
| 195 |
+
texts = list(df["CleanReview"])
|
| 196 |
+
labels = df["Label"]
|
| 197 |
+
label_to_id = {label: idx for idx, label in enumerate(set(labels))}
|
| 198 |
+
true_ids = [label_to_id[label] for label in labels]
|
| 199 |
+
model = SentenceTransformer("j-hartmann/emotion-english-distilroberta-base")
|
| 200 |
+
embeddings = model.encode(texts, show_progress_bar=True)
|
| 201 |
+
|
| 202 |
+
import umap.umap_ as umap

# Project the embeddings down to 2-D (for visualization only)
reducer = umap.UMAP(n_components=2, random_state=42)
embeddings_2d = reducer.fit_transform(embeddings)
|
| 207 |
+
|
| 208 |
+
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# KMeans with one cluster per emotion label
kmeans = KMeans(n_clusters=len(set(labels)), random_state=42)
kmeans_labels = kmeans.fit_predict(embeddings)

# Agreement metrics against the ground-truth labels
ari_kmeans = adjusted_rand_score(true_ids, kmeans_labels)
nmi_kmeans = normalized_mutual_info_score(true_ids, kmeans_labels)

# Print the results
print(f"KMeans:\n - ARI: {ari_kmeans:.3f}\n - NMI: {nmi_kmeans:.3f}")
|
| 221 |
+
|
| 222 |
+
from sklearn.mixture import GaussianMixture
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# Number of clusters (same as the number of emotions)
n_components = len(set(true_ids))

# Fit the GMM
gmm = GaussianMixture(n_components=n_components, random_state=42)
gmm_labels = gmm.fit_predict(embeddings)

# Agreement metrics
ari_gmm = adjusted_rand_score(true_ids, gmm_labels)
nmi_gmm = normalized_mutual_info_score(true_ids, gmm_labels)

print(f"GMM:\n - ARI: {ari_gmm:.3f}\n - NMI: {nmi_gmm:.3f}")
|
| 237 |
+
|
| 238 |
+
import matplotlib.pyplot as plt
import pandas as pd

# Collect the 2-D projection plus each labelling into one frame for plotting.
df_plot = pd.DataFrame(embeddings_2d, columns=["x", "y"])
df_plot["true"] = true_ids
df_plot["kmeans"] = kmeans_labels
df_plot["gmm"] = gmm_labels

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One panel per labelling — ground truth and both clusterings — all drawn over
# the same UMAP projection; coloured only by that panel's column. The loop
# replaces three copy-pasted scatter calls (whose return values were unused).
panels = [("true", "True Labels"), ("kmeans", "KMeans (mpnet)"), ("gmm", "GMM (mpnet)")]
for ax, (column, title) in zip(axes, panels):
    ax.scatter(df_plot["x"], df_plot["y"], c=df_plot[column], cmap="tab10", alpha=0.6)
    ax.set_title(title)

plt.tight_layout()
plt.show()
|
| 260 |
+
|
| 261 |
+
from collections import Counter

# Assign every GMM cluster the majority ground-truth emotion among its members.
cluster_to_label = {}
for cluster_id in range(gmm.n_components):
    # True-label ids of all samples the GMM placed in this cluster.
    member_true_ids = [true_ids[i] for i, assigned in enumerate(gmm_labels) if assigned == cluster_id]
    # Most frequent true label wins.
    cluster_to_label[cluster_id] = Counter(member_true_ids).most_common(1)[0][0]

# Invert label_to_id so ids can be translated back to emotion names.
id_to_label = {label_id: label for label, label_id in label_to_id.items()}
cluster_to_emotion = {cluster: id_to_label[label_id] for cluster, label_id in cluster_to_label.items()}

print("מיפוי קלאסטרים → רגשות:")
print(cluster_to_emotion)
|
| 281 |
+
|
| 282 |
+
import joblib

# Name of the sentence-embedding model used above.
# NOTE(review): this variable is never persisted or used afterwards —
# presumably it was meant to be saved alongside the GMM; confirm and save it
# (e.g. into the JSON below) if so.
embedding_model_name = "j-hartmann/emotion-english-distilroberta-base"

# Persist the fitted GMM clustering model
joblib.dump(gmm, "gmm_model.pkl")

# Persist the cluster -> emotion mapping
import json
with open("cluster_to_emotion.json", "w") as f:
    json.dump(cluster_to_emotion, f)
|
| 294 |
+
|
| 295 |
+
"""# 3. Inputs & outputs."""
|
| 296 |
+
|
| 297 |
+
!pip install -q transformers datasets gradio sentencepiece
|
| 298 |
+
!pip install -q sentence-transformers gradio
|
| 299 |
+
|
| 300 |
+
"""song list & lyrics dataset setup"""
|
| 301 |
+
|
| 302 |
+
import pandas as pd
|
| 303 |
+
|
| 304 |
+
song_db = pd.read_parquet("hf://datasets/johanf/taylor-swift/data/train-00000-of-00001.parquet")
|
| 305 |
+
# Keep only needed columns and drop rows with missing values
|
| 306 |
+
song_db = song_db[["lyrics", "title"]].dropna().drop_duplicates()
|
| 307 |
+
|
| 308 |
+
# Optional: remove extra whitespace
|
| 309 |
+
song_db["lyrics"] = song_db["lyrics"].str.strip()
|
| 310 |
+
song_db["title"] = song_db["title"].str.strip()
|
| 311 |
+
|
| 312 |
+
# Reset index for convenience
|
| 313 |
+
song_db = song_db.reset_index(drop=True)
|
| 314 |
+
|
| 315 |
+
# Preview
|
| 316 |
+
song_db.head()
|
| 317 |
+
|
| 318 |
+
from sentence_transformers import SentenceTransformer

# Load a pre-trained model for semantic similarity
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Compute embeddings for lyrics (takes ~10–20 sec)
lyrics_list = song_db["lyrics"].tolist()
lyrics_embeddings = embedding_model.encode(lyrics_list, show_progress_bar=True)
|
| 326 |
+
|
| 327 |
+
"""Function that finds the best-fitting song
|
| 328 |
+
The find_matching_song function uses a semantic similarity approach to match user input with Taylor Swift lyrics. First, it transforms the user's sentence into a numerical vector using a pre-trained sentence embedding model (all-MiniLM-L6-v2). Then, it compares this vector to pre-computed embeddings of all song lyrics using cosine similarity, which measures how close the meanings are in semantic space. Finally, it ranks the songs by similarity and returns the top matches, showing their titles, a short lyric snippet, and a similarity score.
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 336 |
+
import numpy as np
|
| 337 |
+
|
| 338 |
+
def find_matching_song(user_input, top_k=3):
|
| 339 |
+
# Encode the user input
|
| 340 |
+
user_embedding = embedding_model.encode([user_input])
|
| 341 |
+
|
| 342 |
+
# Compute cosine similarity with all lyrics
|
| 343 |
+
similarities = cosine_similarity(user_embedding, lyrics_embeddings)[0]
|
| 344 |
+
|
| 345 |
+
# Find top matches
|
| 346 |
+
top_indices = np.argsort(similarities)[-top_k:][::-1]
|
| 347 |
+
|
| 348 |
+
# Build output list
|
| 349 |
+
results = []
|
| 350 |
+
for idx in top_indices:
|
| 351 |
+
title = song_db.loc[idx, "title"]
|
| 352 |
+
lyric_snippet = song_db.loc[idx, "lyrics"][:200].replace("\n", " ") + "..."
|
| 353 |
+
score = similarities[idx]
|
| 354 |
+
results.append(f"🎵 **{title}** — _match: {score:.2f}_ \n`{lyric_snippet}`")
|
| 355 |
+
|
| 356 |
+
return "\n\n".join(results)
|
| 357 |
+
|
| 358 |
+
"""the app"""
|
| 359 |
+
|
| 360 |
+
import gradio as gr
|
| 361 |
+
|
| 362 |
+
demo = gr.Interface(
|
| 363 |
+
fn=find_matching_song,
|
| 364 |
+
inputs=gr.Textbox(
|
| 365 |
+
lines=2,
|
| 366 |
+
placeholder="What's on your mind? (e.g. I had ice cream today!)",
|
| 367 |
+
label="Enter your thought or mood"
|
| 368 |
+
),
|
| 369 |
+
outputs=gr.Markdown(),
|
| 370 |
+
title="🎤 Taylor Swift Lyric Matcher",
|
| 371 |
+
description="Enter a sentence, and get the Taylor Swift song with lyrics most semantically similar.",
|
| 372 |
+
examples=[
|
| 373 |
+
["I just broke up with my boyfriend"],
|
| 374 |
+
["I feel nostalgic about my childhood"],
|
| 375 |
+
["I got revenge on someone"],
|
| 376 |
+
["I'm in love and the world feels perfect"],
|
| 377 |
+
["I had ice cream today!"]
|
| 378 |
+
]
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
#app launch
|
| 382 |
+
demo.launch(share=True)
|