File size: 4,228 Bytes
921fc96
7d2baf1
672dce6
7d2baf1
 
 
 
 
 
 
58404fd
 
 
 
 
a5f3455
618ee0a
06bbda0
672dce6
20554b7
672dce6
 
 
58404fd
672dce6
58404fd
 
672dce6
921fc96
7d2baf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672dce6
 
 
7d2baf1
 
672dce6
 
 
 
 
 
 
7d2baf1
672dce6
7d2baf1
 
 
 
 
 
 
 
 
 
20554b7
7d2baf1
 
20554b7
7d2baf1
06bbda0
20554b7
7d2baf1
 
 
06bbda0
 
 
 
 
 
7d2baf1
06bbda0
 
 
 
 
 
7d2baf1
06bbda0
7d2baf1
06bbda0
 
 
 
 
 
7d2baf1
 
 
 
 
 
20554b7
06bbda0
7d2baf1
 
06bbda0
 
 
 
 
 
 
 
 
 
 
 
 
 
7d2baf1
 
 
672dce6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import spaces
import gradio as gr
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
import torch
import os

# Global variables for model and processor.
# Both start as None and are populated by load_model(); transcribe_audio()
# checks them for None before use and returns an error string if unset.
model = None
processor = None

def load_model():
    """Load the model and processor once at startup.

    Populates the module-level ``model`` and ``processor`` globals.
    Idempotent: if both globals are already set, the call is a no-op,
    so repeated invocations never re-download or re-initialize the
    multi-gigabyte checkpoint.
    """
    global model, processor

    # Guard: skip the expensive download/initialization if already loaded.
    if model is not None and processor is not None:
        return

    print("Loading model...")
    model_id = "oddadmix/MasriSwitch-Gemma3n-Transcriber-v1"

    # device_map="auto" lets accelerate place weights on the available
    # device(s); bfloat16 halves memory vs. float32. eval() disables
    # dropout and other train-time behavior for inference.
    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).eval()

    processor = AutoProcessor.from_pretrained(model_id)

    print("Model loaded successfully!")

@spaces.GPU
def transcribe_audio(audio_path, max_tokens=128):
    """Transcribe an audio file with the globally loaded model.

    Returns the decoded transcription string, or a human-readable
    error message (never raises) so the Gradio textbox always gets
    something displayable.
    """
    # Guard clauses: model not yet loaded, or no audio provided.
    if model is None or processor is None:
        return "Error: Model not loaded"
    if audio_path is None:
        return "Please upload or record an audio file"

    try:
        # Chat-style prompt: a fixed system instruction plus a user turn
        # carrying the audio file and the transcription request.
        system_turn = {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant that transcribes speech accurately.",
                }
            ],
        }
        user_turn = {
            "role": "user",
            "content": [
                {"type": "audio", "url": audio_path},
                {"type": "text", "text": "Please transcribe this audio."},
            ],
        }

        model_inputs = processor.apply_chat_template(
            [system_turn, user_turn],
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        # Remember the prompt length so only new tokens are decoded.
        prompt_len = model_inputs["input_ids"].shape[-1]

        # Greedy decoding (do_sample=False) for deterministic transcripts.
        with torch.inference_mode():
            output_ids = model.generate(
                **model_inputs,
                max_new_tokens=max_tokens,
                do_sample=False,
            )
            new_tokens = output_ids[0][prompt_len:]

        return processor.decode(new_tokens, skip_special_tokens=True)

    except Exception as e:
        # Surface any failure as text in the UI rather than crashing.
        return f"Error during transcription: {str(e)}"

# Load model at startup so the first request doesn't pay the load cost.
load_model()

# Create Gradio interface.
# NOTE: the heading emoji below was mojibake ("๐�ŸŽ™๏ธ" — the UTF-8 bytes
# of the microphone emoji mis-decoded through a legacy codepage); repaired
# to the intended 🎙️.
with gr.Blocks(title="Egyptian Code Switching Audio Transcription") as demo:
    gr.Markdown(
        """
        # 🎙️ Egyptian Code Switching Audio Transcription
        
        Upload an audio file or record your voice to get an automatic transcription.
        Specialized for Egyptian Arabic with English code-switching.
        """
    )
    
    with gr.Row():
        with gr.Column():
            # type="filepath" hands transcribe_audio a path on disk, which
            # is what the processor's {"type": "audio", "url": ...} expects.
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Audio Input"
            )
            max_tokens_slider = gr.Slider(
                minimum=32,
                maximum=512,
                value=128,
                step=32,
                label="Max Output Tokens"
            )
            transcribe_btn = gr.Button("Transcribe", variant="primary")
            
        with gr.Column():
            # rtl=True: output is predominantly Arabic, so render right-to-left.
            output_text = gr.Textbox(
                label="Transcription",
                placeholder="Your transcription will appear here...",
                lines=10,
                rtl=True
            )
    
    gr.Markdown(
        """
        ### Tips:
        - For best results, use clear audio with minimal background noise
        - The model specializes in Egyptian Arabic with English code-switching
        - Recording length should be reasonable (under 30 seconds recommended)
        """
    )
    
    # Set up the transcription action
    transcribe_btn.click(
        fn=transcribe_audio,
        inputs=[audio_input, max_tokens_slider],
        outputs=output_text
    )
    
    # Also allow transcription on audio upload/record
    audio_input.change(
        fn=transcribe_audio,
        inputs=[audio_input, max_tokens_slider],
        outputs=output_text
    )

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()