"""Gradio demo that runs IBM Granite Speech on uploaded audio files."""
from datetime import datetime

import gradio as gr
import torch
import torchaudio
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model_name = "ibm-granite/granite-speech-3.3-8b"

# The processor handles audio + text inputs; its tokenizer is used below to
# render the chat template and to decode generated tokens.
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer

# Load the weights in bfloat16 and place them directly on the chosen device.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
def _load_audio_mono_16k(file_path: str) -> torch.Tensor:
    """Load an audio file as a single-channel waveform sampled at 16 kHz.

    Multi-channel audio is down-mixed by averaging across the channel axis
    (dim 0), and any sample rate other than 16000 Hz is resampled to 16000 Hz.

    Args:
        file_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The mono, 16 kHz waveform tensor (channel axis kept with size 1).
    """
    target_rate = 16000
    signal, source_rate = torchaudio.load(file_path, normalize=True)

    # Dim 0 is treated as the channel axis; keepdim preserves a leading 1.
    channel_count = signal.shape[0]
    mono = signal.mean(dim=0, keepdim=True) if channel_count > 1 else signal

    if source_rate == target_rate:
        return mono
    return torchaudio.functional.resample(mono, orig_freq=source_rate, new_freq=target_rate)
# Fallback instruction for a missing or blank user instruction; matches the
# default value pre-filled in the UI's "Instruction" textbox.
_DEFAULT_INSTRUCTION = "can you transcribe the speech into a written format?"


def process_audio(audio_path: str, instruction: str, max_tokens: int = 200) -> str:
    """Run Granite Speech on an audio file and return the generated text.

    Args:
        audio_path: Path to the audio file. An empty value or ``None`` short-
            circuits with a prompt asking the user to upload a file.
        instruction: Text placed after the ``<|audio|>`` placeholder in the
            user turn. ``None`` or whitespace-only input falls back to a
            plain transcription request.
        max_tokens: Maximum number of new tokens to generate. It is cast to
            ``int`` in case the UI delivers a float.

    Returns:
        The decoded continuation with special tokens removed, or a short
        message when no audio was supplied.
    """
    if not audio_path:
        return "Please upload an audio file."

    wav = _load_audio_mono_16k(audio_path)

    # Guard against None (no .strip()) and blank input, so the user turn
    # always carries an instruction after the <|audio|> placeholder.
    task = (instruction or "").strip() or _DEFAULT_INSTRUCTION

    date_string = datetime.now().strftime("%B %d, %Y")
    system_prompt = (
        "Knowledge Cutoff Date: April 2024.\n"
        f"Today's Date: {date_string}.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant"
    )
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"<|audio|>{task}"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    model_inputs = processor(prompt, wav, device=device, return_tensors="pt").to(device)

    # Deterministic greedy decoding: no sampling, a single beam.
    outputs = model.generate(
        **model_inputs,
        max_new_tokens=int(max_tokens),
        do_sample=False,
        num_beams=1,
    )

    # Drop the first num_input_tokens positions so only newly generated
    # tokens are decoded (assumes generate() echoes the prompt tokens first).
    num_input_tokens = model_inputs["input_ids"].shape[-1]
    new_tokens = outputs[0, num_input_tokens:]
    # add_special_tokens is an encoding option; decoding only needs
    # skip_special_tokens to strip markers such as end-of-text.
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
with gr.Blocks(title="Granite Speech Demo") as demo:
    # Page header.
    gr.Markdown("# Granite Speech-to-Text Demo")
    gr.Markdown("Upload audio and transcribe with IBM Granite.")

    with gr.Row():
        # Left column: audio upload, instruction, token budget, and trigger.
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio", type="filepath")
            instruction = gr.Textbox(
                value="can you transcribe the speech into a written format?",
                label="Instruction",
            )
            max_tokens = gr.Slider(
                minimum=50,
                maximum=1000,
                step=50,
                value=200,
                label="Max Output Tokens",
            )
            submit_btn = gr.Button("Transcribe", variant="primary")

        # Right column: model output.
        with gr.Column():
            output_text = gr.Textbox(label="Output", lines=12)

    # Clicking the button feeds the three inputs to process_audio and shows
    # its returned string in the output textbox.
    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input, instruction, max_tokens],
        outputs=output_text,
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server with no public
    # share link and server-side rendering disabled.
    demo.queue()
    demo.launch(share=False, ssr_mode=False)