# app.py — Gradio Space for Egyptian Arabic (code-switching) speech transcription.
import spaces
import gradio as gr
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
import torch
import os
# Global variables for model and processor.
# Both start as None and are populated once by load_model() at startup;
# transcribe_audio() checks for None before using them.
model = None
processor = None
def load_model():
    """Download and initialise the global model and processor.

    Runs once at startup; populates the module-level ``model`` and
    ``processor`` globals that transcribe_audio() relies on.
    """
    global model, processor
    print("Loading model...")
    repo_id = "oddadmix/MasriSwitch-Gemma3n-Transcriber-v1"
    # bfloat16 halves memory versus fp32; device_map="auto" places the
    # weights on whatever accelerator is available.
    loaded = Gemma3nForConditionalGeneration.from_pretrained(
        repo_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model = loaded.eval()  # inference-only: disable dropout etc.
    processor = AutoProcessor.from_pretrained(repo_id)
    print("Model loaded successfully!")
@spaces.GPU
def transcribe_audio(audio_path, max_tokens=128):
    """Transcribe an audio clip with the globally loaded Gemma3n model.

    Args:
        audio_path: Filesystem path to the audio file (Gradio supplies a
            temp-file path, or None when the input is empty/cleared).
        max_tokens: Maximum number of new tokens to generate. Gradio
            sliders can deliver this as a float, so it is coerced to int.

    Returns:
        The decoded transcription, or a human-readable error string
        (the UI textbox displays whatever is returned).
    """
    if model is None or processor is None:
        return "Error: Model not loaded"
    if audio_path is None:
        return "Please upload or record an audio file"
    try:
        # Chat-format prompt: a system instruction plus the user's audio
        # and a text request, as expected by apply_chat_template.
        messages = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an assistant that transcribes speech accurately.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "audio", "url": audio_path},
                    {"type": "text", "text": "Please transcribe this audio."}
                ]
            }
        ]
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        # Record the prompt length so only newly generated tokens are decoded.
        input_len = inputs["input_ids"].shape[-1]
        # Greedy decoding; inference_mode skips autograd bookkeeping.
        with torch.inference_mode():
            generation = model.generate(
                **inputs,
                # int() guards against float values from the Gradio slider,
                # which generate() rejects for max_new_tokens.
                max_new_tokens=int(max_tokens),
                do_sample=False
            )
        generation = generation[0][input_len:]
        response = processor.decode(generation, skip_special_tokens=True)
        return response
    except Exception as e:
        # Surface the failure in the UI rather than crashing the Space.
        return f"Error during transcription: {str(e)}"
# Load model at startup (module import time) so the first request
# does not pay the download/initialisation cost.
load_model()
# Create Gradio interface: audio input + token-budget slider on the left,
# transcription output on the right.
with gr.Blocks(title="Egyptian Code Switching Audio Transcription") as demo:
    gr.Markdown(
        """
        # 🎙️ Egyptian Code Switching Audio Transcription
        Upload an audio file or record your voice to get an automatic transcription.
        Specialized for Egyptian Arabic with English code-switching.
        """
    )
    with gr.Row():
        with gr.Column():
            # type="filepath" hands transcribe_audio a temp-file path
            # rather than raw samples.
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Audio Input"
            )
            # Bounds the generation length passed as max_tokens.
            max_tokens_slider = gr.Slider(
                minimum=32,
                maximum=512,
                value=128,
                step=32,
                label="Max Output Tokens"
            )
            transcribe_btn = gr.Button("Transcribe", variant="primary")
        with gr.Column():
            # rtl=True because output is predominantly Arabic text.
            output_text = gr.Textbox(
                label="Transcription",
                placeholder="Your transcription will appear here...",
                lines=10,
                rtl=True
            )
    gr.Markdown(
        """
        ### Tips:
        - For best results, use clear audio with minimal background noise
        - The model specializes in Egyptian Arabic with English code-switching
        - Recording length should be reasonable (under 30 seconds recommended)
        """
    )
    # Set up the transcription action (explicit button click).
    transcribe_btn.click(
        fn=transcribe_audio,
        inputs=[audio_input, max_tokens_slider],
        outputs=output_text
    )
    # Also allow transcription on audio upload/record.
    # NOTE(review): .change also fires when the input is cleared; in that
    # case audio_path is None and transcribe_audio returns a prompt message.
    audio_input.change(
        fn=transcribe_audio,
        inputs=[audio_input, max_tokens_slider],
        outputs=output_text
    )
# Launch the app only when run as a script (HF Spaces executes app.py directly).
if __name__ == "__main__":
    demo.launch()