File size: 5,223 Bytes
46b55ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import re
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from .config import RANDOM_STATE, ASSETS_DIR

# -------------------------
# Prompt Construction
# -------------------------

def build_critic_prompt(genre, artist, title, lyrics, max_lyric_length=300):
    """Build the Gemma-formatted prompt that asks the Critic persona for feedback.

    Args:
        genre: Target genre shown to the critic.
        artist: Target artist whose style the lyrics should match.
        title: Target song title.
        lyrics: Lyrics to evaluate; only the first ``max_lyric_length``
            characters are embedded to keep the prompt short.
        max_lyric_length: Maximum number of lyric characters to include.

    Returns:
        A single user turn followed by an opened model turn, so the critic's
        reply can be appended directly (as done when building training rows).
    """
    excerpt = lyrics[:max_lyric_length]

    # The trailing spaces before two of the newlines are part of the trained
    # prompt format and must be kept verbatim.
    critic_rules = "\n".join((
        "You are a professional music critic. Provide specific feedback on how to improve "
        "the lyrics based on the genre and artist style. ",
        "Formatting Rules: ",
        "1. Use plain text with clear line breaks.",
        "2. Ensure all song titles and words have proper spacing.",
    ))
    target_header = "\n".join((
        f"Target Genre: {genre}",
        f"Target Artist: {artist}",
        f"Target Title: {title}",
    ))
    user_text = f"{critic_rules}\n\n{target_header}\n\nLyrics to Evaluate:\n{excerpt}"

    return "<start_of_turn>user\n" + user_text + "<end_of_turn>\n<start_of_turn>model\n"

def build_revision_prompt(genre, artist, title, draft, critiques):
    """Build the Gemma-formatted prompt for the 'Revise' step of the refinement loop.

    Args:
        genre: Genre the revised lyrics must stay in.
        artist: Artist whose style the revision should keep.
        title: Song title.
        draft: Current lyric draft to be revised (embedded in full).
        critiques: Critic feedback to incorporate (embedded in full).

    Returns:
        A single user turn followed by an opened model turn.
    """
    task = (
        "You are an expert songwriter. Revise the provided lyrics by incorporating "
        "the specific feedback from the critic while maintaining the genre and artist style."
    )

    song_meta = "\n".join((f"Genre: {genre}", f"Artist Style: {artist}", f"Title: {title}"))
    # Metadata, draft and feedback are separated by one blank line each.
    details = "\n\n".join((
        song_meta,
        f"Current Draft:\n{draft}",
        f"Critic Feedback:\n{critiques}",
    ))

    return "<start_of_turn>user\n" + task + "\n\n" + details + "<end_of_turn>\n<start_of_turn>model\n"

def build_inference_prompt(genre, artist, title):
    """Rebuild the v1.0 training prompt (without the answer) for generation.

    The text must stay identical to the user turn produced by ``format_prompt``
    so the fine-tuned model sees the same format at inference time.

    Returns:
        A single user turn followed by an opened model turn.
    """
    song_details = "\n".join((
        f"Genre: {genre}",
        f"Artist: {artist}",
        f"Title: {title}",
    ))
    user_text = "Generate lyrics for a song based on these details." + "\n\n" + song_details
    return "<start_of_turn>user\n" + user_text + "<end_of_turn>\n" + "<start_of_turn>model\n"

def format_prompt(row):
    """Turn one dataset row into a complete Gemma training example.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) indexed by the keys
            ``'tag'``, ``'artist'``, ``'title'`` and ``'lyrics'``.

    Returns:
        The user turn (instruction + song details) followed by a closed model
        turn containing the row's lyrics as the target response.
    """
    genre, artist, title, lyrics = (row[key] for key in ("tag", "artist", "title", "lyrics"))
    song_details = "\n".join((f"Genre: {genre}", f"Artist: {artist}", f"Title: {title}"))

    return (
        "<start_of_turn>user\n"
        "Generate lyrics for a song based on these details.\n\n"
        + song_details
        + "<end_of_turn>\n"
        + f"<start_of_turn>model\n{lyrics}<end_of_turn>"
    )

# -------------------------
# Text Processing
# -------------------------

def format_lyrics(text):
    """Clean up raw model output by enforcing structural newlines and spacing.

    Two heuristics are applied:

    1. Section headers such as ``[Verse]`` or ``[Chorus: Artist]`` are put on
       their own line, separated from the previous section by exactly one
       blank line. Whitespace already surrounding a header is absorbed first,
       so pre-existing newlines do not pile up into extra blank lines.
    2. A newline is inserted wherever a lowercase ASCII letter is directly
       followed by an uppercase one (``"nightOh"`` -> ``"night\\nOh"``). This
       undoes lines that were concatenated in the generated text. Text inside
       square brackets is skipped so headers like ``[Verse: McCartney]`` stay
       intact.

    Args:
        text: Raw generated lyrics. ``None`` or ``""`` yields ``""``.

    Returns:
        The reformatted lyrics with leading/trailing whitespace stripped.
    """
    if not text:
        return ""

    # Header on its own line, preceded by a blank line. Swallowing the
    # surrounding whitespace (\s*) keeps spacing uniform whether or not the
    # input already had newlines around the header.
    text = re.sub(r'\s*(\[[^\[\]\n]*\])\s*', r'\n\n\1\n', text)

    # Split camel-case joins, but not inside a bracketed header. The negative
    # lookahead rejects positions where a ']' follows on the same line before
    # any '[', i.e. positions inside a header.
    text = re.sub(r'([a-z])([A-Z])(?![^\[\]\n]*\])', r'\1\n\2', text)

    # Adjacent headers (or headers next to blank lines) can still yield 3+
    # newlines; cap every run at a single blank line.
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()

# -------------------------
# Dataset Management
# -------------------------

def format_critic_training_row(row):
    """Render one dataset row as a Critic instruction-tuning example.

    Args:
        row: Record exposing ``tag``, ``artist``, ``title`` and ``lyrics`` as
            attributes (e.g. a pandas row from ``itertuples``, or a Series).

    Returns:
        The critic prompt from ``build_critic_prompt``, immediately followed
        by a three-line target critique and an ``<eos>`` marker.
    """
    prompt_text = build_critic_prompt(row.tag, row.artist, row.title, row.lyrics)

    # NOTE(review): the target critique is a fixed template. Only the genre and
    # artist names vary, and the "Improvements" line is identical for every
    # row, so the critic learns this template rather than lyric-specific
    # feedback. Confirm this placeholder is intentional.
    critique = "\n".join((
        f"Genre Fit: The {row.tag} style is well-maintained.",
        f"Artist Style: Matches the {row.artist} aesthetic.",
        "Improvements: Consider refining the rhythmic flow in the second verse.",
    ))

    return prompt_text + critique + "<eos>"

def prepare_lyric_dataset(lyrics_filename, reviews_filename, songs_per_genre=200, data_dir="data"):
    """Load, clean, merge and genre-balance the lyric dataset, exporting EDA plots.

    Steps:
      1. Read both CSVs from ``data_dir`` (malformed lyric rows are skipped).
      2. Drop rows missing the columns needed downstream.
      3. Keep the first lyric row per artist and attach it to every review of
         that artist; reviews with no matching lyric row are dropped.
      4. Save a bar chart of the 10 most frequent ``tag`` values ("before").
      5. Down-sample each ``tag`` to at most ``songs_per_genre`` rows, seeded
         with ``RANDOM_STATE`` for reproducibility.
      6. Save a count plot of the balanced ``tag`` distribution ("after").

    Args:
        lyrics_filename: CSV file name inside ``data_dir``; read columns are
            ``artist``, ``lyrics``, ``title`` and ``tag``.
        reviews_filename: CSV file name inside ``data_dir``; read columns are
            ``genre`` and ``artist``.
        songs_per_genre: Upper bound on rows kept per ``tag``.
        data_dir: Directory containing both CSVs. Defaults to ``"data"``
            (relative to the current working directory), which matches the
            previously hard-coded location.

    Returns:
        The balanced DataFrame, grouped by ``tag`` in sorted order, with the
        original row index preserved.

    Raises:
        ValueError: If no review matches a lyric row. Without this check, the
            plots and balancing would run on an empty frame and the empty
            result would only surface later.
    """
    from .viz import save_figure

    lyrics_path = os.path.join(data_dir, lyrics_filename)
    reviews_path = os.path.join(data_dir, reviews_filename)

    print("Loading & Cleaning Raw Data...")

    lyrics_df = pd.read_csv(lyrics_path, on_bad_lines='skip')
    reviews_df = pd.read_csv(reviews_path)

    lyrics_df = lyrics_df.dropna(subset=['lyrics', 'artist', 'tag'])
    reviews_df = reviews_df.dropna(subset=['genre', 'artist'])

    # One lyric row per artist; every review of that artist reuses it.
    lyrics_clean = lyrics_df.drop_duplicates(subset="artist")[["artist", "lyrics", "title", "tag"]]
    # NOTE(review): if the reviews CSV also has a 'title' column, pandas will
    # suffix both ('title_x'/'title_y') and downstream row['title'] lookups
    # would fail. Confirm the reviews schema.
    merged_df = reviews_df.merge(lyrics_clean, on="artist", how="left").dropna(subset=["lyrics", "tag"])

    if merged_df.empty:
        raise ValueError(
            f"No reviews in {reviews_path!r} matched an artist in {lyrics_path!r}; "
            "cannot build the lyric dataset."
        )

    # --- Plot 1: Raw Distribution ("Before") ---
    plt.figure(figsize=(10, 5))
    top_raw = merged_df['tag'].value_counts().nlargest(10)
    sns.barplot(x=top_raw.values, y=top_raw.index, hue=top_raw.index, palette='viridis', legend=False)
    plt.title(f"Raw Genre Distribution (n={len(merged_df):,})")
    save_figure("eda_1_raw_distribution.png")

    # Class balancing: sample each tag group separately and concatenate. This
    # matches groupby(..., group_keys=False).apply(sample) (same rows, order
    # and index). Unlike apply, it does not depend on the deprecated behaviour
    # of passing the grouping column into the function, so 'tag' is kept.
    balanced_df = pd.concat(
        group.sample(min(len(group), songs_per_genre), random_state=RANDOM_STATE)
        for _, group in merged_df.groupby("tag")
    )

    # --- Plot 2: Balanced Distribution ("After") ---
    plt.figure(figsize=(10, 5))
    sns.countplot(data=balanced_df, y='tag', hue='tag', palette='magma', legend=False)
    plt.title(f"Balanced Genre Distribution (n={len(balanced_df):,})")
    save_figure("eda_2_balanced_distribution.png")

    return balanced_df