Spaces:
Paused
Paused
| import gradio as gr | |
| import torch | |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
| import librosa | |
| import numpy as np | |
# -------- Load Model (small) --------
# Checkpoint name for the text generator; the tokenizer and model below are
# loaded once, at import time, and shared by every request the app serves.
MODEL_NAME = "distilgpt2"
tokenizer, model = (
    GPT2Tokenizer.from_pretrained(MODEL_NAME),
    GPT2LMHeadModel.from_pretrained(MODEL_NAME),
)
# -------- Session State --------
class SessionState:
    """Mutable store for the analyzed beat and the lyrics written so far.

    NOTE(review): the app keeps a single module-level instance (``state``),
    so every visitor shares it and concurrent users would overwrite each
    other's beat and lyrics; consider ``gr.State`` if that matters.
    """

    def __init__(self):
        # Both stay None until a beat has been analyzed; lyric generation
        # checks ``tempo`` for exactly that.
        self.tempo = self.energy = None
        # Lyric lines accumulated so far, oldest first.
        self.lyrics = []


state = SessionState()
# -------- Beat Analysis (lightweight) --------
def analyze_beat(audio):
    """Estimate the tempo and average loudness of an audio file.

    Args:
        audio: Path to the audio file (the UI supplies a filepath).

    Returns:
        ``(tempo_bpm, energy)``:
        - ``tempo_bpm`` is librosa's tempo estimate rounded to the nearest
          whole BPM.
        - ``energy`` is the mean absolute sample amplitude of the 16 kHz mono
          signal, rounded to 3 decimals.
    """
    # Resample to 16 kHz mono to keep the analysis cheap.
    y, sr = librosa.load(audio, sr=16000, mono=True)
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
    # Depending on the librosa version, tempo may come back as a 1-element
    # ndarray. Calling int() directly on a non-0-d array is deprecated in
    # NumPy, so extract the scalar explicitly.
    tempo_bpm = float(np.atleast_1d(tempo)[0])
    energy = float(np.mean(np.abs(y)))
    # Round rather than truncate so e.g. 119.9 BPM reports as 120.
    return int(round(tempo_bpm)), round(energy, 3)
# -------- Generate Lyrics --------
def generate_lines(mood, lines=4, regenerate=False):
    """Append up to ``lines`` freshly generated lyric lines to the session.

    Builds a prompt from the lyrics written so far plus the analyzed beat's
    tempo/energy and the chosen ``mood``, samples a continuation from the
    language model, and appends the first non-empty lines of that
    continuation to ``state.lyrics``.

    Args:
        mood: Style label inserted into the prompt (e.g. "Chill", "Trap").
        lines: Maximum number of new lines to append.
        regenerate: When True, drop the last two stored lines (or as many as
            exist) before generating, so they get replaced.

    Returns:
        All lyrics so far joined by newlines, or a hint string when no beat
        has been analyzed yet.
    """
    if state.tempo is None:
        return "Upload a beat first."

    if regenerate:
        # Slice deletion also handles lists shorter than 2 (removes what exists).
        del state.lyrics[-2:]

    context = "\n".join(state.lyrics)
    history = f"{context}\n" if context else ""
    prompt = (
        f"{history}"
        f"Rap lyrics for a {mood} beat at {state.tempo} BPM "
        f"with energy {state.energy}:\n"
    )

    # Rough budget of ~12 tokens per requested lyric line.
    max_new_tokens = lines * 12
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # GPT-2 has a fixed context window (config.n_positions). The lyric history
    # grows with every click, so keep only the most recent prompt tokens; the
    # instruction sits at the end of the prompt and is therefore preserved.
    max_prompt_tokens = max(1, model.config.n_positions - max_new_tokens)
    if input_ids.shape[1] > max_prompt_tokens:
        input_ids = input_ids[:, -max_prompt_tokens:]
        attention_mask = attention_mask[:, -max_prompt_tokens:]

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the generated continuation. Stripping the prompt text out of
    # a full decode breaks whenever detokenization does not reproduce it
    # byte-for-byte.
    new_text = tokenizer.decode(
        output[0, input_ids.shape[1]:], skip_special_tokens=True
    )
    clean_lines = [ln.strip() for ln in new_text.split("\n") if ln.strip()][:lines]
    state.lyrics.extend(clean_lines)
    return "\n".join(state.lyrics)
# -------- Upload Handler --------
def handle_upload(audio):
    """Analyze an uploaded beat and start a fresh lyric session for it.

    Args:
        audio: Filepath from the audio component, or None when no file has
            been provided.

    Returns:
        A status message for the "Beat Info" textbox.
    """
    # Clicking "Analyze Beat" before uploading passes None, which the audio
    # loader cannot handle; report it instead of crashing.
    if audio is None:
        return "Please upload an audio file first."
    tempo, energy = analyze_beat(audio)
    state.tempo = tempo
    state.energy = energy
    # A new beat invalidates lyrics written for the previous one.
    state.lyrics = []
    return f"Beat analyzed: {tempo} BPM | Energy: {energy}"
# -------- UI --------
with gr.Blocks() as demo:
    gr.Markdown("## 🎵 Beat-to-Lyrics Generator (Free CPU Optimized)")
    audio_input = gr.Audio(type="filepath")
    upload_btn = gr.Button("Analyze Beat")
    beat_info = gr.Textbox(label="Beat Info")
    mood = gr.Dropdown(
        ["Chill", "Hype", "Trap", "Lo-fi", "Boom-bap"],
        value="Chill",
        label="Mood",
    )
    lyrics_output = gr.Textbox(lines=20, label="Lyrics")
    generate_btn = gr.Button("Generate Lines")
    regenerate_btn = gr.Button("Regenerate Last 2 Lines")

    upload_btn.click(handle_upload, audio_input, beat_info)
    # "Generate" appends up to 6 new lines to the running lyrics.
    generate_btn.click(lambda m: generate_lines(m, 6, False), mood, lyrics_output)
    # "Regenerate" drops the last 2 lines and writes up to 2 replacements, as
    # the button label promises, so the song does not grow on each click.
    regenerate_btn.click(lambda m: generate_lines(m, 2, True), mood, lyrics_output)

# Only start the server when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()