"""Gradio demo: bilingual emotion recognition from text or transcribed audio."""
import gradio as gr
import plotly.express as px
import pandas as pd
import logging
import whisper
import torch
import numpy as np
|
|
| from model import Mamba |
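# NOTE: `Mamba` is the project's own classifier from the local model.py and is
# assumed to expose the scikit-learn-style predict_proba() used below.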
|
|
| logging.basicConfig(level=logging.INFO) |
|
|
def plotly_plot_text(text):
    """Return an emotion-probability bar chart and a dominant-emotion headline."""
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😀 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    # Guard against empty input so the model never scores a blank string.
    data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * data.shape[0]
    p = px.bar(data, x='Emotion', y='Probability', color="Probability")
    # Return exactly two values: the .change() wiring below expects [text_plot, top_emotion].
    return (
        p,
        f"## ☝️ Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
    )
|
|
def transcribe_audio(audio_path):
    """Transcribe an audio file with Whisper, returning '' on failure."""
    # Load Whisper once and cache it on the function, rather than on every call.
    if not hasattr(transcribe_audio, "whisper_model"):
        transcribe_audio.whisper_model = whisper.load_model("base")
    try:
        result = transcribe_audio.whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""
|
|
def plotly_plot_audio(audio_path):
    """Transcribe audio, then return (plot, transcription text, dominant-emotion headline)."""
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😀 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    try:
        text = transcribe_audio(audio_path)
        # Fall back to all-zero probabilities when transcription comes back empty.
        data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            f"🎤 Transcription:\n{text}",
            f"## ☝️ Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        data['Probability'] = [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            "❌ Error processing audio",
            "⚠️ Processing Error"
        )
| |
| def create_demo_text(): |
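    """Build the text tab: a textbox wired to a live emotion-probability plot."""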
| with gr.Blocks(theme='Nymbo/rounded-gradient', css=".gradio-container {background-color: #F0F8FF}", title="Emotion Detection") as demo: |
| gr.Markdown("# Text-based bilingual emotion recognition") |
|
|
| with gr.Row(): |
| text_input = gr.Textbox(label="Write Text") |
|
|
| with gr.Row(): |
            top_emotion = gr.Markdown("## ☝️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
|
|
| with gr.Row(): |
| text_plot = gr.Plot(label="Text Analysis") |
|
|
| text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, top_emotion]) |
| return demo |
|
|
| def create_demo_audio(): |
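    """Build the audio tab: record/upload audio, transcribe it, then classify the text."""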
| with gr.Blocks(theme='Nymbo/rounded-gradient', css=".gradio-container {background-color: #F0F8FF}", title="Emotion Detection") as demo: |
| gr.Markdown("# Text-based bilingual emotion recognition with audio transcription") |
|
|
| with gr.Row(): |
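            # type="filepath" makes Gradio pass a temporary audio file path,
            # which whisper's transcribe() can read directly.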
| audio_input = gr.Audio( |
| sources=["upload", "microphone"], |
| type="filepath", |
| label="Record or Upload Audio", |
| format="wav", |
| interactive=True |
| ) |
| with gr.Row(): |
            top_emotion = gr.Markdown("## ☝️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
|
|
| with gr.Row(): |
| text_plot = gr.Plot(label="Text Analysis") |
|
|
| transcription = gr.Textbox( |
| label="π Transcription Results", |
| placeholder="Transcribed text will appear here...", |
| lines=3, |
| max_lines=6 |
| ) |
| audio_input.change(fn=plotly_plot_audio, inputs=audio_input, outputs=[text_plot, transcription, top_emotion]) |
| return demo |
|
|
| def create_demo(): |
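    """Combine the text and audio demos into one tabbed interface."""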
| text = create_demo_text() |
| audio = create_demo_audio() |
| demo = gr.TabbedInterface( |
| [text, audio], |
| ["Text Prediction", "Transcribed Audio Prediction"], |
| ) |
| return demo |
| |
|
|
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7, model_name='jina', pooling=None).to(device)
    # map_location=device loads the checkpoint onto whichever device is available.
    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference only: disable dropout and other training-time behaviour
|
|
| demo = create_demo() |
| demo.launch() |