# neural-mathrock / app.py
# Source: Hugging Face Space "anggars/neural-mathrock" (commit beaad7b)
import torch
import torch.nn as nn
import librosa
import numpy as np
import gradio as gr
from transformers import XLMRobertaModel, XLMRobertaTokenizer, WavLMModel
from huggingface_hub import hf_hub_download
import yt_dlp
import os
import re
import lyricsgenius
import syncedlyrics
from youtubesearchpython import VideosSearch
import warnings
warnings.filterwarnings('ignore')  # demo app: silence library deprecation chatter

# -- CONFIGURATION --
REPO_ID = "anggars/neural-mathrock"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Prediction targets; order must match the model's output heads and the UI labels.
TARGET_COLS = ['mbti', 'emotion', 'vibe', 'intensity', 'tempo']
# SECURITY: never commit API tokens to source. Configure GENIUS_TOKEN as an
# environment variable / Space secret; an empty token simply makes the Genius
# lookup fail, and the app falls back to syncedlyrics (see search_and_fetch).
GENIUS_TOKEN = os.environ.get("GENIUS_TOKEN", "")

print("Fetching model weights...")
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.pt")
# NOTE: weights_only=False unpickles arbitrary objects (the sklearn label
# encoders stored in the checkpoint). Only acceptable because the checkpoint
# comes from our own trusted repo -- never do this with untrusted files.
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
# Label encoders for each prediction target, restored from the checkpoint.
le_mbti, le_emotion, le_vibe, le_intensity, le_tempo = (
    ckpt['le_mbti'], ckpt['le_emotion'], ckpt['le_vibe'],
    ckpt['le_intensity'], ckpt['le_tempo'],
)
# -- ARCHITECTURE --
class HybridMultimodalModel(nn.Module):
    """Gated text+audio fusion network with five classification heads.

    Text is encoded by an XLM-RoBERTa backbone and audio by WavLM; each
    modality is modulated by a learned sigmoid gate, concatenated, fused
    through a small MLP, and fed to one linear head per target in
    TARGET_COLS. Attribute names must stay exactly as-is so that
    ``load_state_dict`` can match the checkpoint's parameter keys.
    """

    def __init__(self):
        super().__init__()
        # Backbone encoders; both produce 768-dim hidden states.
        self.text_model = XLMRobertaModel.from_pretrained('anggars/xlm-mbti')
        self.audio_model = WavLMModel.from_pretrained('microsoft/wavlm-base')
        self.audio_proj = nn.Linear(768, 256)
        # Per-modality sigmoid gates applied elementwise before fusion.
        self.text_gate = nn.Sequential(nn.Linear(768, 768), nn.Sigmoid())
        self.audio_gate = nn.Sequential(nn.Linear(256, 256), nn.Sigmoid())
        # Fusion MLP over the concatenated (768 + 256)-dim features.
        self.fusion = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
        )
        # One classification head per target; output sizes come from the
        # label encoders loaded from the checkpoint.
        self.head_mbti = nn.Linear(512, len(le_mbti.classes_))
        self.head_emotion = nn.Linear(512, len(le_emotion.classes_))
        self.head_vibe = nn.Linear(512, len(le_vibe.classes_))
        self.head_intensity = nn.Linear(512, len(le_intensity.classes_))
        self.head_tempo = nn.Linear(512, len(le_tempo.classes_))

    def forward(self, input_ids, attention_mask, audio_values, text_missing=False):
        """Return a dict mapping each target in TARGET_COLS to its logits.

        When ``text_missing`` is True (instrumental track) the pooled text
        representation is zeroed so only the audio branch contributes.
        """
        pooled_text = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        if text_missing:
            # Instrumental: suppress the text modality entirely.
            pooled_text = torch.zeros_like(pooled_text)
        # Mean-pool WavLM frame features over time, then project to 256-d.
        frames = self.audio_model(audio_values).last_hidden_state
        pooled_audio = self.audio_proj(frames.mean(dim=1))
        # Gate each modality, concatenate, and fuse.
        combined = torch.cat(
            [pooled_text * self.text_gate(pooled_text),
             pooled_audio * self.audio_gate(pooled_audio)],
            dim=-1,
        )
        fused = self.fusion(combined)
        logits = {}
        for col in TARGET_COLS:
            logits[col] = getattr(self, f'head_{col}')(fused)
        return logits
# Instantiate the network and restore the trained weights from the checkpoint.
model = HybridMultimodalModel().to(DEVICE)
# strict=False tolerates checkpoint keys that do not match this architecture
# (e.g. buffers from a slightly different training-time module layout).
model.load_state_dict(ckpt['model_state'], strict=False)
model.eval()
# Tokenizer paired with the XLM-RoBERTa text backbone used above.
tokenizer = XLMRobertaTokenizer.from_pretrained('anggars/xlm-mbti')
# -- UTILITIES --
def search_and_fetch(query):
    """Search YouTube for `query`, download the audio as WAV, and fetch lyrics.

    Returns a ``(audio_path, lyrics)`` tuple: ``audio_path`` is the path of
    the downloaded WAV file (or None when the query is empty, nothing was
    found, or the download failed) and ``lyrics`` is the cleaned lyric text,
    "" when none were found, or an error message on failure.
    """
    if not query or query.strip() == "":
        return None, ""
    try:
        search = VideosSearch(query, limit=1)
        res = search.result()
        if not res['result']:
            return None, "No results."
        video_url = res['result'][0]['link']
        temp_fn = "temp_audio_file"
        # Remove any stale WAV from a previous fetch so yt-dlp does not
        # skip the download and we analyze the wrong track.
        if os.path.exists(f"{temp_fn}.wav"):
            os.remove(f"{temp_fn}.wav")
        ydl_opts = {
            'format': 'bestaudio/best',
            'outtmpl': f'{temp_fn}.%(ext)s',
            'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav', 'preferredquality': '192'}],
            'quiet': True,
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([video_url])
        lyrics = ""
        # Lyrics are best-effort: try Genius first, then synced-lyrics as a
        # fallback; any provider failure just leaves lyrics empty
        # (treated as instrumental downstream).
        try:
            genius = lyricsgenius.Genius(GENIUS_TOKEN, verbose=False)
            gsong = genius.search_song(query)
            if gsong:
                # Strip section markers such as [Chorus].
                lyrics = re.sub(r'\[.*?\]', '', gsong.lyrics).strip()
            else:
                slrc = syncedlyrics.search(query)
                if slrc:
                    # Strip LRC timestamps such as [01:23.45].
                    lyrics = re.sub(r'\[\d{2}:\d{2}\.\d{2}\]', '', slrc).strip()
        except Exception:
            pass
        return f"{temp_fn}.wav", lyrics
    except Exception as e:
        return None, str(e)
def analyze_track(audio_path, lyrics_input):
    """Run the multimodal model on one track and return top-3 predictions.

    Parameters:
        audio_path: path to an audio file readable by librosa, or falsy.
        lyrics_input: lyric text; empty/None marks the track instrumental.

    Returns a list of five dicts (one per target in TARGET_COLS, in order:
    mbti, emotion, vibe, intensity, tempo), each mapping the top-3 class
    names to probabilities. On any failure, returns ``[{"Error": msg}] * 5``
    so every output widget shows the message.
    """
    if not audio_path:
        return [{"Error": "No audio"}] * 5
    try:
        is_inst = not lyrics_input or str(lyrics_input).strip() == ""
        # Instrumental tracks use a placeholder token; the model also zeroes
        # the text features via text_missing=True.
        text = "[INSTRUMENTAL]" if is_inst else str(lyrics_input).strip()
        enc = tokenizer(text, truncation=True, padding='max_length',
                        max_length=128, return_tensors='pt').to(DEVICE)
        wav, sr = librosa.load(audio_path, sr=16000)
        tempo_bpm, _ = librosa.beat.beat_track(y=wav, sr=sr)
        # Newer librosa returns the tempo as an ndarray; reduce it to a plain
        # float so the BPM comparison below is well-defined.
        tempo_bpm = float(np.atleast_1d(tempo_bpm)[0])
        chunk_len = 16000 * 15  # 15-second windows at 16 kHz
        # Split into windows; drop trailing fragments shorter than 1 second.
        chunks = [wav[i:i + chunk_len] for i in range(0, len(wav), chunk_len)
                  if len(wav[i:i + chunk_len]) >= 16000]
        if not chunks:
            return [{"Error": "Audio too short (need at least 1 second)"}] * 5
        all_logits = {col: [] for col in TARGET_COLS}
        with torch.no_grad():
            for chunk in chunks:
                # Zero-pad the final window up to the full chunk length.
                if len(chunk) < chunk_len:
                    chunk = np.pad(chunk, (0, chunk_len - len(chunk)))
                audio_t = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(DEVICE)
                out = model(enc['input_ids'], enc['attention_mask'], audio_t,
                            text_missing=is_inst)
                for col in TARGET_COLS:
                    all_logits[col].append(out[col][0])
        final_results = []
        encoders = {'mbti': le_mbti, 'emotion': le_emotion, 'vibe': le_vibe,
                    'intensity': le_intensity, 'tempo': le_tempo}
        for col in TARGET_COLS:
            # Average logits across chunks, then sharpen with temperature 0.7.
            final_logits = torch.stack(all_logits[col]).mean(dim=0) / 0.7
            if col == 'tempo':
                # Heuristic: nudge the 'Fast' class when the measured BPM is
                # high; skipped harmlessly if the encoder has no such class.
                t_cls = list(le_tempo.classes_)
                if tempo_bpm > 125 and 'Fast' in t_cls:
                    final_logits[t_cls.index('Fast')] += 3.0
            probs = torch.nn.functional.softmax(final_logits, dim=0).cpu().numpy()
            classes = encoders[col].classes_
            res_dict = {str(c): float(p) for c, p in zip(classes, probs)}
            # Keep only the three most likely classes for display.
            final_results.append(dict(sorted(res_dict.items(),
                                             key=lambda x: x[1], reverse=True)[:3]))
        return final_results
    except Exception as e:
        return [{"Error": str(e)}] * 5
# -- INTERFACE --
# Two-column layout: inputs (search / audio / lyrics) on the left,
# one Label widget per prediction target on the right.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Neural Math Rock Multimodal Analysis")
    gr.Markdown("Identify personality and emotional states from music audio and lyrics.")
    with gr.Row():
        with gr.Column():
            # Input side: YouTube search fills the audio/lyrics boxes, which
            # the user can also populate or edit manually.
            search_box = gr.Textbox(label="YouTube Search (Artist - Song Title)", placeholder="Enter song name...")
            fetch_btn = gr.Button("FETCH AUDIO AND LYRICS", variant="secondary")
            gr.HTML("<hr>")
            audio_box = gr.Audio(type="filepath", label="Audio Source")
            lyrics_box = gr.Textbox(lines=6, label="Lyrics Source", placeholder="Lyrics...")
            run_btn = gr.Button("RUN ANALYSIS", variant="primary")
        with gr.Column():
            # Output side: order matches the five entries of TARGET_COLS.
            res_mbti = gr.Label(label="Personality (MBTI)")
            res_emo = gr.Label(label="Emotional State")
            res_vibe = gr.Label(label="Acoustic Vibe")
            res_int = gr.Label(label="Intensity Level")
            res_tmp = gr.Label(label="Tempo Classification")
    # Wire buttons: fetch populates the inputs; run produces the five labels.
    fetch_btn.click(fn=search_and_fetch, inputs=[search_box], outputs=[audio_box, lyrics_box])
    run_btn.click(fn=analyze_track, inputs=[audio_box, lyrics_box], outputs=[res_mbti, res_emo, res_vibe, res_int, res_tmp])

if __name__ == "__main__":
    demo.launch()