# Hugging Face Space: anggars/neural-mathrock (Gradio demo app)
import os
import re
import warnings

import gradio as gr
import librosa
import lyricsgenius
import numpy as np
import syncedlyrics
import torch
import torch.nn as nn
import yt_dlp
from huggingface_hub import hf_hub_download
from transformers import XLMRobertaModel, XLMRobertaTokenizer, WavLMModel
from youtubesearchpython import VideosSearch

warnings.filterwarnings('ignore')
# -- CONFIGURATION --
REPO_ID = "anggars/neural-mathrock"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Prediction heads emitted by the model, in fixed order (must match the UI outputs).
TARGET_COLS = ['mbti', 'emotion', 'vibe', 'intensity', 'tempo']
# SECURITY: the original source shipped a hard-coded Genius API token as the
# fallback value; that credential is leaked and must be rotated. Supply the
# token exclusively via the GENIUS_TOKEN environment variable.
GENIUS_TOKEN = os.environ.get("GENIUS_TOKEN", "")

print("Fetching model weights...")
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.pt")
# weights_only=False because the checkpoint bundles label-encoder objects.
# NOTE(review): this unpickles arbitrary objects -- only load trusted checkpoints.
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
# Fitted label encoders for each prediction head, saved alongside the weights.
le_mbti, le_emotion, le_vibe, le_intensity, le_tempo = (
    ckpt['le_mbti'], ckpt['le_emotion'], ckpt['le_vibe'],
    ckpt['le_intensity'], ckpt['le_tempo'],
)
| # -- ARCHITECTURE -- | |
# -- ARCHITECTURE --
class HybridMultimodalModel(nn.Module):
    """Gated text+audio fusion network with five classification heads.

    Text comes from an XLM-RoBERTa backbone, audio from WavLM; each modality
    is passed through a sigmoid gate, concatenated, fused, and fed to one
    linear head per target in TARGET_COLS.
    """

    def __init__(self):
        super().__init__()
        # Pretrained backbones: XLM-RoBERTa for lyrics, WavLM for raw audio.
        self.text_model = XLMRobertaModel.from_pretrained('anggars/xlm-mbti')
        self.audio_model = WavLMModel.from_pretrained('microsoft/wavlm-base')
        # Project WavLM's 768-dim features down to 256.
        self.audio_proj = nn.Linear(768, 256)
        # Per-modality sigmoid gates (element-wise feature weighting).
        self.text_gate = nn.Sequential(nn.Linear(768, 768), nn.Sigmoid())
        self.audio_gate = nn.Sequential(nn.Linear(256, 256), nn.Sigmoid())
        # 768 (gated text) + 256 (gated audio) = 1024 fused input features.
        self.fusion = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
        )
        # One linear head per target; class counts come from the encoders.
        label_encoders = {
            'mbti': le_mbti,
            'emotion': le_emotion,
            'vibe': le_vibe,
            'intensity': le_intensity,
            'tempo': le_tempo,
        }
        for col, enc in label_encoders.items():
            setattr(self, f'head_{col}', nn.Linear(512, len(enc.classes_)))

    def forward(self, input_ids, attention_mask, audio_values, text_missing=False):
        """Return {target: logits} for a batch of (text, audio) pairs."""
        pooled_text = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        # For instrumental tracks, zero the text representation so the
        # prediction rests on audio evidence alone.
        if text_missing:
            pooled_text = torch.zeros_like(pooled_text)
        hidden_audio = self.audio_model(audio_values).last_hidden_state
        pooled_audio = self.audio_proj(hidden_audio.mean(dim=1))
        # Gate each modality, concatenate, then fuse.
        gated = torch.cat(
            [
                pooled_text * self.text_gate(pooled_text),
                pooled_audio * self.audio_gate(pooled_audio),
            ],
            dim=-1,
        )
        fused = self.fusion(gated)
        return {col: getattr(self, f'head_{col}')(fused) for col in TARGET_COLS}
# Instantiate the network, restore the trained weights, and switch to
# inference mode. strict=False tolerates key mismatches between the
# checkpoint and the current architecture.
model = HybridMultimodalModel().to(DEVICE)
model.load_state_dict(ckpt['model_state'], strict=False)
model.eval()

# Tokenizer matching the text backbone used inside the model.
tokenizer = XLMRobertaTokenizer.from_pretrained('anggars/xlm-mbti')
| # -- UTILITIES -- | |
# -- UTILITIES --
def search_and_fetch(query):
    """Search YouTube for *query*, download its audio as WAV, and fetch lyrics.

    Returns (wav_path, lyrics) on success, or (None, message) on failure.
    Lyrics retrieval is best-effort: Genius first, synced-lyrics fallback;
    any lyrics-provider failure yields empty lyrics, not an error.
    """
    if not query or not query.strip():
        return None, ""
    try:
        search = VideosSearch(query, limit=1)
        res = search.result()
        if not res['result']:
            return None, "No results."
        video_url = res['result'][0]['link']

        temp_fn = "temp_audio_file"
        wav_path = f"{temp_fn}.wav"
        # Remove any stale download so yt-dlp always writes a fresh file.
        if os.path.exists(wav_path):
            os.remove(wav_path)
        ydl_opts = {
            'format': 'bestaudio/best',
            'outtmpl': f'{temp_fn}.%(ext)s',
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'wav',
                'preferredquality': '192',
            }],
            'quiet': True,
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([video_url])

        lyrics = ""
        try:
            genius = lyricsgenius.Genius(GENIUS_TOKEN, verbose=False)
            gsong = genius.search_song(query)
            if gsong:
                # Strip Genius section markers like [Chorus].
                lyrics = re.sub(r'\[.*?\]', '', gsong.lyrics).strip()
            else:
                slrc = syncedlyrics.search(query)
                if slrc:
                    # Strip LRC timestamps like [01:23.45].
                    lyrics = re.sub(r'\[\d{2}:\d{2}\.\d{2}\]', '', slrc).strip()
        except Exception:
            # Lyrics are optional; treat any provider failure as "no lyrics".
            pass
        return wav_path, lyrics
    except Exception as e:
        return None, str(e)
def analyze_track(audio_path, lyrics_input):
    """Run the multimodal model on an audio file plus optional lyrics.

    Returns a list of five {class_name: probability} dicts (top-3 classes
    each), ordered as TARGET_COLS: mbti, emotion, vibe, intensity, tempo.
    On any failure the same error dict is returned five times so each
    Gradio Label output still receives a value.
    """
    if not audio_path:
        return [{"Error": "No audio"}] * 5
    try:
        # Empty lyrics means instrumental: feed a sentinel token and tell the
        # model to zero out the text representation.
        is_inst = not lyrics_input or not str(lyrics_input).strip()
        text = "[INSTRUMENTAL]" if is_inst else str(lyrics_input).strip()
        enc = tokenizer(text, truncation=True, padding='max_length',
                        max_length=128, return_tensors='pt').to(DEVICE)

        wav, sr = librosa.load(audio_path, sr=16000)
        # Recent librosa versions return the tempo as an ndarray; normalize
        # to a plain float so the comparison below cannot silently fail.
        tempo_bpm, _ = librosa.beat.beat_track(y=wav, sr=sr)
        tempo_bpm = float(np.atleast_1d(tempo_bpm)[0])

        # Split audio into 15 s chunks at 16 kHz; drop fragments under 1 s.
        chunk_len = 16000 * 15
        chunks = [wav[i:i + chunk_len] for i in range(0, len(wav), chunk_len)
                  if len(wav[i:i + chunk_len]) >= 16000]
        if not chunks:
            # Original code crashed on torch.stack([]) for sub-second audio.
            return [{"Error": "Audio too short (need at least 1 second)"}] * 5

        all_logits = {col: [] for col in TARGET_COLS}
        with torch.no_grad():
            for chunk in chunks:
                if len(chunk) < chunk_len:
                    chunk = np.pad(chunk, (0, chunk_len - len(chunk)))
                audio_t = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(DEVICE)
                out = model(enc['input_ids'], enc['attention_mask'], audio_t,
                            text_missing=is_inst)
                for col in TARGET_COLS:
                    all_logits[col].append(out[col][0])

        encoders = {'mbti': le_mbti, 'emotion': le_emotion, 'vibe': le_vibe,
                    'intensity': le_intensity, 'tempo': le_tempo}
        temperature = 0.7  # <1 sharpens the chunk-averaged logits before softmax
        final_results = []
        for col in TARGET_COLS:
            final_logits = torch.stack(all_logits[col]).mean(dim=0) / temperature
            if col == 'tempo':
                # Nudge the 'Fast' class when the measured BPM clearly is.
                t_cls = list(le_tempo.classes_)
                if tempo_bpm > 125 and 'Fast' in t_cls:
                    final_logits[t_cls.index('Fast')] += 3.0
            probs = torch.nn.functional.softmax(final_logits, dim=0).cpu().numpy()
            classes = encoders[col].classes_
            ranked = sorted(
                ((str(c), float(p)) for c, p in zip(classes, probs)),
                key=lambda kv: kv[1],
                reverse=True,
            )
            final_results.append(dict(ranked[:3]))
        return final_results
    except Exception as e:
        return [{"Error": str(e)}] * 5
| # -- INTERFACE -- | |
# -- INTERFACE --
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Neural Math Rock Multimodal Analysis")
    gr.Markdown("Identify personality and emotional states from music audio and lyrics.")
    with gr.Row():
        # Left column: inputs (search/fetch, audio file, lyrics text).
        with gr.Column():
            search_box = gr.Textbox(
                label="YouTube Search (Artist - Song Title)",
                placeholder="Enter song name...",
            )
            fetch_btn = gr.Button("FETCH AUDIO AND LYRICS", variant="secondary")
            gr.HTML("<hr>")
            audio_box = gr.Audio(type="filepath", label="Audio Source")
            lyrics_box = gr.Textbox(lines=6, label="Lyrics Source", placeholder="Lyrics...")
            run_btn = gr.Button("RUN ANALYSIS", variant="primary")
        # Right column: one Label widget per prediction head (TARGET_COLS order).
        with gr.Column():
            res_mbti = gr.Label(label="Personality (MBTI)")
            res_emo = gr.Label(label="Emotional State")
            res_vibe = gr.Label(label="Acoustic Vibe")
            res_int = gr.Label(label="Intensity Level")
            res_tmp = gr.Label(label="Tempo Classification")

    # Wire buttons to the backend functions.
    fetch_btn.click(fn=search_and_fetch, inputs=[search_box],
                    outputs=[audio_box, lyrics_box])
    run_btn.click(fn=analyze_track, inputs=[audio_box, lyrics_box],
                  outputs=[res_mbti, res_emo, res_vibe, res_int, res_tmp])

if __name__ == "__main__":
    demo.launch()