File size: 5,223 Bytes
46b55ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import re
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from .config import RANDOM_STATE, ASSETS_DIR
# -------------------------
# Prompt Construction
# -------------------------
def build_critic_prompt(genre, artist, title, lyrics, max_lyric_length=300):
"""Constructs the instruction-tuning prompt for the Critic persona."""
lyrics_snippet = lyrics[:max_lyric_length]
instruction = (
"You are a professional music critic. Provide specific feedback on how to improve "
"the lyrics based on the genre and artist style. \n"
"Formatting Rules: \n"
"1. Use plain text with clear line breaks.\n"
"2. Ensure all song titles and words have proper spacing."
)
context = (
f"Target Genre: {genre}\n"
f"Target Artist: {artist}\n"
f"Target Title: {title}\n\n"
f"Lyrics to Evaluate:\n{lyrics_snippet}"
)
return f"<start_of_turn>user\n{instruction}\n\n{context}<end_of_turn>\n<start_of_turn>model\n"
def build_revision_prompt(genre, artist, title, draft, critiques):
    """Build the prompt for the 'Revise' step: rewrite a draft using critic feedback.

    The user turn holds the songwriter instruction, the song metadata, the
    current draft, and the critic's feedback, each separated by a blank line.
    The result ends with an opened model turn.
    """
    task = (
        "You are an expert songwriter. Revise the provided lyrics by incorporating "
        "the specific feedback from the critic while maintaining the genre and artist style."
    )

    # Blank-line-separated sections that make up the rest of the user turn.
    sections = [
        f"Genre: {genre}\nArtist Style: {artist}\nTitle: {title}",
        f"Current Draft:\n{draft}",
        f"Critic Feedback:\n{critiques}",
    ]
    payload = "\n\n".join([task, *sections])

    return "<start_of_turn>user\n" + payload + "<end_of_turn>\n<start_of_turn>model\n"
def build_inference_prompt(genre, artist, title):
    """Rebuild the v1.0 training prompt layout, stopping at the opened model turn."""
    task = "Generate lyrics for a song based on these details."
    details = "\n".join((f"Genre: {genre}", f"Artist: {artist}", f"Title: {title}"))

    user_turn = f"<start_of_turn>user\n{task}\n\n{details}<end_of_turn>\n"
    # No response text follows the model marker; the model is expected to fill it in.
    model_turn = "<start_of_turn>model\n"
    return user_turn + model_turn
def format_prompt(row):
    """Turn one dataset row into a full Gemma control-token training example.

    ``row`` is indexed by the keys ``tag``, ``artist``, ``title`` and
    ``lyrics``. The lyrics become the model turn, closed by ``<end_of_turn>``.
    """
    genre, artist, title = row['tag'], row['artist'], row['title']
    lyrics = row['lyrics']

    task = "Generate lyrics for a song based on these details."
    details = f"Genre: {genre}\nArtist: {artist}\nTitle: {title}"

    user_turn = f"<start_of_turn>user\n{task}\n\n{details}<end_of_turn>\n"
    model_turn = f"<start_of_turn>model\n{lyrics}<end_of_turn>"
    return user_turn + model_turn
# -------------------------
# Text Processing
# -------------------------
# Bracketed section header such as [Verse] or [Chorus]; the capture group makes
# re.split keep the headers in its output.
_SECTION_HEADER_RE = re.compile(r'(\[.*?\])')
# A lowercase letter immediately followed by an uppercase one ("wordWord").
_FUSED_WORDS_RE = re.compile(r'([a-z])([A-Z])')


def format_lyrics(text):
    """Clean up raw model output by enforcing structural newlines and spacing.

    Each bracketed section header (e.g. ``[Verse]``, ``[Chorus]``) is placed on
    its own line, preceded by a blank line. In the text between headers, a
    newline is inserted wherever an uppercase letter directly follows a
    lowercase one, on the assumption that two lines were fused together.

    The fused-line split is applied only outside section headers, so a header
    such as ``[PreChorus]`` is left in one piece.

    Args:
        text: Raw lyric text.

    Returns:
        The reformatted text with leading and trailing whitespace stripped.
    """
    # With a capturing pattern, re.split alternates plain text (even indices)
    # with the matched headers (odd indices).
    pieces = _SECTION_HEADER_RE.split(text)

    formatted = []
    for index, piece in enumerate(pieces):
        if index % 2:
            # Header: blank line before, single newline after.
            formatted.append(f"\n\n{piece}\n")
        else:
            # Lyric text: break lines that were run together.
            formatted.append(_FUSED_WORDS_RE.sub(r'\1\n\2', piece))

    return "".join(formatted).strip()
# -------------------------
# Dataset Management
# -------------------------
def format_critic_training_row(row):
    """Render one raw row as a complete Critic instruction-tuning example.

    ``row`` is read through the attributes ``tag``, ``artist``, ``title`` and
    ``lyrics``. The target text is a fixed three-line template filled with the
    row's genre tag and artist, appended to the critic prompt and terminated
    with ``<eos>``.
    """
    critic_prompt = build_critic_prompt(row.tag, row.artist, row.title, row.lyrics)

    # Templated critique; only the genre tag and artist vary between rows.
    critique_lines = [
        f"Genre Fit: The {row.tag} style is well-maintained.",
        f"Artist Style: Matches the {row.artist} aesthetic.",
        "Improvements: Consider refining the rhythmic flow in the second verse.",
    ]
    critique = "\n".join(critique_lines)

    return critic_prompt + critique + "<eos>"
def prepare_lyric_dataset(lyrics_filename, reviews_filename, songs_per_genre=200):
    """Load, clean, and genre-balance the dataset while exporting EDA plots.

    Both CSV files are read from the relative ``data`` directory. Lyrics are
    reduced to one row per artist, joined onto the reviews by ``artist``, and
    then down-sampled so that no ``tag`` value has more than
    ``songs_per_genre`` rows. A "before" and an "after" genre-distribution
    plot are handed to ``save_figure``.

    Args:
        lyrics_filename: CSV file name under ``data/``; the columns ``artist``,
            ``lyrics``, ``title`` and ``tag`` are used.
        reviews_filename: CSV file name under ``data/``; the columns ``genre``
            and ``artist`` are used.
        songs_per_genre: Maximum number of rows kept for each ``tag`` value.

    Returns:
        DataFrame of the merged rows, with at most ``songs_per_genre`` rows
        per ``tag``.
    """
    # Imported locally, as in the original; presumably to avoid an import
    # cycle with the viz module (confirm).
    from .viz import save_figure

    lyrics_path = os.path.join("data", lyrics_filename)
    reviews_path = os.path.join("data", reviews_filename)

    print("Loading & Cleaning Raw Data...")
    lyrics_df = pd.read_csv(lyrics_path, on_bad_lines='skip')
    reviews_df = pd.read_csv(reviews_path)

    # NOTE(review): 'title' is not part of this dropna, so rows with a missing
    # title survive into the result - confirm that is intended.
    lyrics_df = lyrics_df.dropna(subset=['lyrics', 'artist', 'tag'])
    reviews_df = reviews_df.dropna(subset=['genre', 'artist'])

    # Keep one lyrics row per artist so the merge below does not multiply
    # review rows; reviews without matching lyrics are dropped afterwards.
    lyrics_clean = lyrics_df.drop_duplicates(subset="artist")[["artist", "lyrics", "title", "tag"]]
    merged_df = reviews_df.merge(lyrics_clean, on="artist", how="left").dropna(subset=["lyrics", "tag"])

    # --- Plot 1: Raw Distribution ("Before") ---
    plt.figure(figsize=(10, 5))
    top_raw = merged_df['tag'].value_counts().nlargest(10)
    sns.barplot(x=top_raw.values, y=top_raw.index, hue=top_raw.index, palette='viridis', legend=False)
    plt.title(f"Raw Genre Distribution (n={len(merged_df):,})")
    save_figure("eda_1_raw_distribution.png")

    # Class balancing logic.
    # Sample each genre explicitly instead of using groupby(...).apply(...):
    # apply() operating on the grouping column is deprecated in pandas 2.2 and
    # excluded in later versions, which would drop 'tag' from the result.
    # Iterating the groupby always yields the full frame for each group.
    per_genre_samples = [
        group.sample(min(len(group), songs_per_genre), random_state=RANDOM_STATE)
        for _, group in merged_df.groupby("tag")
    ]
    if per_genre_samples:
        balanced_df = pd.concat(per_genre_samples)
    else:
        # No rows survived cleaning: return an empty frame with the same
        # columns, since pd.concat raises on an empty list.
        balanced_df = merged_df.iloc[0:0]

    # --- Plot 2: Balanced Distribution ("After") ---
    plt.figure(figsize=(10, 5))
    sns.countplot(data=balanced_df, y='tag', hue='tag', palette='magma', legend=False)
    plt.title(f"Balanced Genre Distribution (n={len(balanced_df):,})")
    save_figure("eda_2_balanced_distribution.png")

    return balanced_df