File size: 3,907 Bytes
cbf450b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
"""ASR_Deployment.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1MmePYOn1Ho2FhILi00u9UbvsujEoHhot
"""

import gradio as gr
from transformers import WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig
import torch
import librosa
import os

# --- 1. CONFIGURATION ---
# Hugging Face Hub repo id (or local directory) of the fine-tuned Whisper
# checkpoint. It can be overridden with the ASR_MODEL_PATH environment variable
# without editing the code; the default is the original checkpoint.
# NOTE: no token is passed explicitly anywhere in this script. If the repository
# is private, the Hub client presumably picks credentials up from the
# environment (e.g. HF_TOKEN or a prior `huggingface-cli login`), and that token
# needs "Read" access to the repository — confirm for your deployment.
MODEL_PATH = os.environ.get("ASR_MODEL_PATH", "MaryWambo/whisper-base-kikuyu4")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 2. LOAD MODEL & PROCESSOR ---
print(f"Loading model to {device}...")
try:
    processor = WhisperProcessor.from_pretrained(MODEL_PATH)
    model = WhisperForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)

    # Define Generation Config to avoid "outdated" errors.
    # Language and task are set here so they don't conflict in the generate() call.
    gen_config = GenerationConfig.from_pretrained(MODEL_PATH)
    # NOTE(review): "swahili" is used as the decoder language token for this
    # Kikuyu model; use the full name or "sw" to match how it was trained — confirm.
    gen_config.language = "swahili"
    gen_config.task = "transcribe"
    # Clear legacy forced decoder ids so they don't clash with language/task above.
    gen_config.forced_decoder_ids = None
    gen_config.suppress_tokens = []

    model.generation_config = gen_config

except Exception as e:
    # Loading is best-effort so the UI can still launch. Always bind
    # `processor` and `model`, though, so the request handler sees an explicit
    # "not loaded" state (None) instead of failing with a NameError. The full
    # traceback goes to the server logs for debugging.
    import traceback

    print(f"Error loading model: {e}")
    traceback.print_exc()
    processor = None
    model = None

# --- 3. CUSTOM CSS ---
# Stylesheet passed to gr.Blocks(css=...) in the UI section.
# - The #title-text, #predict-box and #run-btn selectors target the elem_id
#   values given to the title Markdown, the output Textbox and the run Button.
# - The .upload-button / .mic-button / .clear-button / .record-button selectors
#   presumably match Gradio's internal audio-control class names.
#   NOTE(review): these are not public Gradio API — confirm they still exist in
#   the installed Gradio version.
# The dark-red accent (#8b0000) is used for the title, the icons, the output
# border and the run button.
custom_css = """
body, .gradio-container { background-color: white !important; }
#title-text h1 { color: #8b0000 !important; font-weight: 900 !important; text-align: center; }
.upload-button svg, .mic-button svg, .clear-button svg, .record-button svg {
    transform: scale(1.5) !important;
    color: #8b0000 !important;
}
#predict-box textarea {
    font-size: 1.6rem !important;
    font-weight: 800 !important;
    color: #000000 !important;
    border: 3px solid #8b0000 !important;
}
#run-btn {
    background: #8b0000 !important;
    color: white !important;
    font-weight: bold !important;
    font-size: 1.4rem !important;
}
"""

# --- 4. LOGIC FUNCTIONS ---
def transcribe_kikuyu(audio):
    """Transcribe one Kikuyu speech recording with the fine-tuned Whisper model.

    Args:
        audio: Path to an audio file, as supplied by the ``gr.Audio`` component
            (``type="filepath"``), or ``None`` when nothing was recorded or
            uploaded.

    Returns:
        The decoded transcription text. If no audio was given, the model is not
        loaded, or transcription fails, a human-readable message is returned
        instead. Errors are returned as text rather than raised, so they show up
        in the output textbox.
    """
    if audio is None:
        return "Please record or upload audio."

    # Model loading at startup is best-effort. If it failed, `model` and
    # `processor` may be unbound or None, so report that explicitly instead of
    # showing the user a NameError/TypeError message.
    if globals().get("model") is None or globals().get("processor") is None:
        return "Error during transcription: model is not loaded (check the server logs)."

    try:
        # Load the audio and resample it to 16 kHz, the rate Whisper expects.
        # NOTE(review): the Whisper feature extractor pads/truncates to a fixed
        # 30 s window by default, so longer clips are probably cut off — confirm.
        speech_array, sampling_rate = librosa.load(audio, sr=16000)

        # Convert the waveform to model input features on the target device.
        inputs = processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt")
        input_features = inputs.input_features.to(device)

        with torch.no_grad():
            # language/task are taken from model.generation_config (set at load
            # time), so they are deliberately not passed here.
            generated_ids = model.generate(
                input_features=input_features,
                num_beams=5,
                max_new_tokens=255,
            )

        # A single clip was encoded, so take the single decoded string.
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    except Exception as e:
        return f"Error during transcription: {e}"

# --- 5. BUILD GRADIO UI ---
# Two-column layout: audio capture and the run button on the left, the
# transcript on the right. Clicking the button sends the audio path to
# transcribe_kikuyu and writes its return value into the textbox.
with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
    gr.Markdown(value="# 🎙️ Kikuyu ASR ", elem_id="title-text")

    with gr.Row():
        # Left column: speech input plus the trigger button.
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="🎤 Record/Upload Kikuyu Speech",
                type="filepath",  # the handler receives a file path, which librosa.load reads
                sources=["microphone", "upload"],
            )
            submit_btn = gr.Button(value="🚀 RUN TRANSCRIPTION", elem_id="run-btn")

        # Right column: where the transcription is shown.
        with gr.Column(scale=1):
            text_out = gr.Textbox(lines=8, elem_id="predict-box", label="🤖 AI Prediction")

    submit_btn.click(
        transcribe_kikuyu,
        inputs=[audio_input],
        outputs=[text_out],
    )

# --- 6. LAUNCH ---
if __name__ == "__main__":
    # share=True asks Gradio for a temporary public URL in addition to the local
    # server, and debug=True keeps the launch call blocking so errors are printed.
    # NOTE(review): the share link's lifetime depends on the Gradio version
    # (historically 72 hours) — confirm.
    launch_options = dict(share=True, debug=True)
    demo.launch(**launch_options)