import torch import torch.nn as nn import librosa import numpy as np import gradio as gr from transformers import XLMRobertaModel, XLMRobertaTokenizer, WavLMModel from huggingface_hub import hf_hub_download import yt_dlp import os import re import lyricsgenius import syncedlyrics from youtubesearchpython import VideosSearch import warnings warnings.filterwarnings('ignore') # -- CONFIGURATION -- REPO_ID = "anggars/neural-mathrock" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") TARGET_COLS = ['mbti', 'emotion', 'vibe', 'intensity', 'tempo'] GENIUS_TOKEN = os.environ.get("GENIUS_TOKEN", "z2XGBWXalGUtAdC1qxxXBxUnK1ZuoHPkCu5eP9q-fed-DW1uCJ3NSFpHemk3Unmg") print("Fetching model weights...") model_path = hf_hub_download(repo_id=REPO_ID, filename="model.pt") ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False) le_mbti, le_emotion, le_vibe, le_intensity, le_tempo = ckpt['le_mbti'], ckpt['le_emotion'], ckpt['le_vibe'], ckpt['le_intensity'], ckpt['le_tempo'] # -- ARCHITECTURE -- class HybridMultimodalModel(nn.Module): def __init__(self): super().__init__() self.text_model = XLMRobertaModel.from_pretrained('anggars/xlm-mbti') self.audio_model = WavLMModel.from_pretrained('microsoft/wavlm-base') self.audio_proj = nn.Linear(768, 256) self.text_gate = nn.Sequential(nn.Linear(768, 768), nn.Sigmoid()) self.audio_gate = nn.Sequential(nn.Linear(256, 256), nn.Sigmoid()) self.fusion = nn.Sequential( nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.4), ) self.head_mbti = nn.Linear(512, len(le_mbti.classes_)) self.head_emotion = nn.Linear(512, len(le_emotion.classes_)) self.head_vibe = nn.Linear(512, len(le_vibe.classes_)) self.head_intensity = nn.Linear(512, len(le_intensity.classes_)) self.head_tempo = nn.Linear(512, len(le_tempo.classes_)) def forward(self, input_ids, attention_mask, audio_values, text_missing=False): text_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask) text_feat = text_out.pooler_output # Jika instrumental, matikan representasi teks if text_missing: text_feat = torch.zeros_like(text_feat) audio_out = self.audio_model(audio_values).last_hidden_state audio_feat = self.audio_proj(audio_out.mean(dim=1)) # Gating mechanism gated_text = text_feat * self.text_gate(text_feat) gated_audio = audio_feat * self.audio_gate(audio_feat) # Fusion fused = self.fusion(torch.cat([gated_text, gated_audio], dim=-1)) return {col: getattr(self, f'head_{col}')(fused) for col in TARGET_COLS} model = HybridMultimodalModel().to(DEVICE) model.load_state_dict(ckpt['model_state'], strict=False) model.eval() tokenizer = XLMRobertaTokenizer.from_pretrained('anggars/xlm-mbti') # -- UTILITIES -- def search_and_fetch(query): if not query or query.strip() == "": return None, "" try: search = VideosSearch(query, limit=1) res = search.result() if not res['result']: return None, "No results." video_url = res['result'][0]['link'] temp_fn = "temp_audio_file" ydl_opts = { 'format': 'bestaudio/best', 'outtmpl': f'{temp_fn}.%(ext)s', 'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav', 'preferredquality': '192'}], 'quiet': True, } with yt_dlp.YoutubeDL(ydl_opts) as ydl: ydl.download([video_url]) lyrics = "" try: genius = lyricsgenius.Genius(GENIUS_TOKEN, verbose=False) gsong = genius.search_song(query) if gsong: lyrics = re.sub(r'\[.*?\]', '', gsong.lyrics).strip() else: slrc = syncedlyrics.search(query) if slrc: lyrics = re.sub(r'\[\d{2}:\d{2}\.\d{2}\]', '', slrc).strip() except: pass return f"{temp_fn}.wav", lyrics except Exception as e: return None, str(e) def analyze_track(audio_path, lyrics_input): if not audio_path: return [{"Error": "No audio"}] * 5 try: is_inst = not lyrics_input or str(lyrics_input).strip() == "" text = str(lyrics_input).strip() if not is_inst else "[INSTRUMENTAL]" enc = tokenizer(text, truncation=True, padding='max_length', max_length=128, return_tensors='pt').to(DEVICE) wav, sr = librosa.load(audio_path, sr=16000) tempo_bpm, _ = librosa.beat.beat_track(y=wav, sr=sr) chunk_len = 16000 * 15 chunks = [wav[i:i + chunk_len] for i in range(0, len(wav), chunk_len) if len(wav[i:i+chunk_len]) >= 16000] all_logits = {col: [] for col in TARGET_COLS} with torch.no_grad(): for chunk in chunks: if len(chunk) < chunk_len: chunk = np.pad(chunk, (0, chunk_len - len(chunk))) audio_t = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(DEVICE) out = model(enc['input_ids'], enc['attention_mask'], audio_t, text_missing=is_inst) for col in TARGET_COLS: all_logits[col].append(out[col][0]) final_results = [] encoders = {'mbti': le_mbti, 'emotion': le_emotion, 'vibe': le_vibe, 'intensity': le_intensity, 'tempo': le_tempo} for col in TARGET_COLS: logits_stack = torch.stack(all_logits[col]) # Pake mean tanpa weight manual biar model audio kerja murni final_logits = logits_stack.mean(dim=0) / 0.7 if col == 'tempo': t_cls = list(le_tempo.classes_) try: if tempo_bpm > 125: final_logits[t_cls.index('Fast')] += 3.0 except: pass probs = torch.nn.functional.softmax(final_logits, dim=0).cpu().numpy() classes = encoders[col].classes_ res_dict = {str(classes[i]): float(probs[i]) for i in range(len(classes))} final_results.append(dict(sorted(res_dict.items(), key=lambda x: x[1], reverse=True)[:3])) return final_results except Exception as e: return [{"Error": str(e)}] * 5 # -- INTERFACE -- with gr.Blocks(theme=gr.themes.Monochrome()) as demo: gr.Markdown("# Neural Math Rock Multimodal Analysis") gr.Markdown("Identify personality and emotional states from music audio and lyrics.") with gr.Row(): with gr.Column(): search_box = gr.Textbox(label="YouTube Search (Artist - Song Title)", placeholder="Enter song name...") fetch_btn = gr.Button("FETCH AUDIO AND LYRICS", variant="secondary") gr.HTML("