balaharan committed on
Commit
e95056a
·
verified ·
1 Parent(s): a00d269

requirements.txt

Browse files

transformers>=4.45.0
torch>=2.0.0
torchaudio>=2.0.0
gradio>=4.0.0
soundfile>=0.12.0
accelerate>=0.21.0

Files changed (1) hide show
  1. app.py +81 -101
app.py CHANGED
@@ -1,75 +1,92 @@
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
5
- import numpy as np
6
 
7
- # Global variables to store model and processor
 
 
 
8
  model = None
9
  processor = None
10
  device = None
11
 
12
  def load_model():
13
- """Load the Granite Speech model and processor"""
14
  global model, processor, device
15
 
16
  try:
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
18
  model_name = "ibm-granite/granite-speech-3.3-2b"
19
 
20
- # Load processor and model
 
21
  processor = AutoProcessor.from_pretrained(model_name)
22
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
23
 
24
- return f"✅ Model loaded successfully on {device}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  except Exception as e:
26
  return f"❌ Error loading model: {str(e)}"
27
 
28
- def transcribe_audio(audio_file, task_type="transcribe"):
29
- """
30
- Transcribe audio using Granite Speech model
31
-
32
- Args:
33
- audio_file: Audio file path from Gradio
34
- task_type: "transcribe" or "translate"
35
- """
36
  global model, processor, device
37
 
38
  if model is None or processor is None:
39
- return "❌ Model not loaded. Please load the model first."
 
 
 
40
 
41
  try:
42
  # Load and preprocess audio
43
- if audio_file is None:
44
- return "❌ Please upload an audio file"
45
-
46
- # Load audio file
47
  wav, sr = torchaudio.load(audio_file)
48
 
49
- # Ensure mono and 16kHz
50
  if wav.shape[0] > 1:
51
- wav = wav.mean(dim=0, keepdim=True) # Convert to mono
 
 
52
  if sr != 16000:
53
  resampler = torchaudio.transforms.Resample(sr, 16000)
54
  wav = resampler(wav)
55
 
56
- # Normalize audio
57
- wav = torchaudio.functional.normalize_audio(wav)
58
-
59
- # Create chat template
60
- if task_type == "transcribe":
61
- user_content = "<|audio|>can you transcribe the speech into a written format?"
62
- else: # translate
63
- user_content = "<|audio|>can you translate this speech to English?"
64
 
 
65
  chat = [
66
  {
67
  "role": "system",
68
- "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: April 9, 2025.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
69
  },
70
  {
71
  "role": "user",
72
- "content": user_content,
73
  }
74
  ]
75
 
@@ -83,116 +100,79 @@ def transcribe_audio(audio_file, task_type="transcribe"):
83
  model_inputs = processor(
84
  text,
85
  wav,
86
- device=device,
87
  return_tensors="pt",
 
88
  ).to(device)
89
 
90
- # Generate transcription
91
  with torch.no_grad():
92
- model_outputs = model.generate(
93
  **model_inputs,
94
- max_new_tokens=200,
95
- num_beams=4,
96
  do_sample=False,
97
- min_length=1,
98
- top_p=1.0,
99
- repetition_penalty=1.0,
100
- length_penalty=1.0,
101
  temperature=1.0,
102
- bos_token_id=tokenizer.bos_token_id,
103
- eos_token_id=tokenizer.eos_token_id,
104
  pad_token_id=tokenizer.pad_token_id,
105
  )
106
 
107
  # Decode output
108
  num_input_tokens = model_inputs["input_ids"].shape[-1]
109
- new_tokens = model_outputs[0, num_input_tokens:].unsqueeze(0)
110
- output_text = tokenizer.batch_decode(
111
- new_tokens, add_special_tokens=False, skip_special_tokens=True
112
  )[0]
113
 
114
- return f"✅ {task_type.capitalize()} Result:\n\n{output_text}"
115
 
116
  except Exception as e:
117
- return f"❌ Error during {task_type}: {str(e)}"
118
 
119
- def create_interface():
120
- """Create the Gradio interface"""
121
-
122
- with gr.Blocks(title="Granite Speech 3.3-2B Demo", theme=gr.themes.Soft()) as demo:
123
  gr.Markdown("""
124
  # 🎀 IBM Granite Speech 3.3-2B Demo
125
 
126
- This demo uses IBM's Granite Speech 3.3-2B model for automatic speech recognition (ASR) and speech translation.
127
 
128
- **Supported Languages**: English, French, German, Spanish, Portuguese
129
-
130
- **Features**:
131
- - 📝 Speech-to-text transcription
132
- - 🌍 Speech translation to English
133
- - 🔄 Two-pass design for improved accuracy
134
  """)
135
 
136
  with gr.Row():
137
  with gr.Column():
138
- # Model loading section
139
- gr.Markdown("### 1. Load Model")
140
- load_btn = gr.Button("🔄 Load Granite Speech Model", variant="primary")
141
- load_status = gr.Textbox(label="Status", interactive=False)
142
 
143
- # Audio input section
144
- gr.Markdown("### 2. Upload Audio")
145
- audio_input = gr.Audio(
146
  label="Upload Audio File",
147
  type="filepath",
148
  format="wav"
149
  )
150
 
151
- # Task selection
152
- task_choice = gr.Radio(
153
- choices=["transcribe", "translate"],
154
- value="transcribe",
155
- label="Task",
156
- info="Choose whether to transcribe or translate to English"
157
- )
158
-
159
- # Process button
160
- process_btn = gr.Button("🎯 Process Audio", variant="secondary")
161
 
162
  with gr.Column():
163
- # Output section
164
- gr.Markdown("### 3. Results")
165
- output_text = gr.Textbox(
166
- label="Output",
167
- lines=10,
168
- interactive=False,
169
- placeholder="Transcription or translation will appear here..."
170
  )
171
 
172
- # Example audio section
173
  gr.Markdown("""
174
- ### 📋 Usage Tips:
175
- - **Audio format**: Upload WAV, MP3, or other common audio formats
176
- - **Quality**: Clear speech works best (16kHz recommended)
177
- - **Length**: Keep audio clips reasonable in length for free tier
178
- - **Languages**: Works with English, French, German, Spanish, Portuguese
179
  """)
180
 
181
  # Event handlers
182
- load_btn.click(
183
- fn=load_model,
184
- outputs=load_status
185
- )
186
-
187
- process_btn.click(
188
- fn=transcribe_audio,
189
- inputs=[audio_input, task_choice],
190
- outputs=output_text
191
- )
192
 
193
  return demo
194
 
195
- # Create and launch the interface
196
  if __name__ == "__main__":
197
- demo = create_interface()
198
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
+ import warnings
5
+ import os
6
 
7
+ # Suppress warnings for cleaner output
8
+ warnings.filterwarnings("ignore")
9
+
10
+ # Global variables
11
  model = None
12
  processor = None
13
  device = None
14
 
15
  def load_model():
16
+ """Load the Granite Speech model with error handling"""
17
  global model, processor, device
18
 
19
  try:
20
+ # Check available device
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ print(f"Using device: {device}")
23
+
24
+ # Import here to catch import errors
25
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
26
+
27
  model_name = "ibm-granite/granite-speech-3.3-2b"
28
 
29
+ # Load with memory optimization for free tier
30
+ print("Loading processor...")
31
  processor = AutoProcessor.from_pretrained(model_name)
 
32
 
33
+ print("Loading model...")
34
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
35
+ model_name,
36
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
37
+ low_cpu_mem_usage=True,
38
+ ).to(device)
39
+
40
+ # Set to eval mode
41
+ model.eval()
42
+
43
+ return f"✅ Model loaded successfully on {device}!"
44
+
45
+ except ImportError as e:
46
+ return f"❌ Import error: {str(e)}. Please check requirements.txt"
47
+ except torch.cuda.OutOfMemoryError:
48
+ return "❌ GPU out of memory. Try restarting the Space or use CPU."
49
  except Exception as e:
50
  return f"❌ Error loading model: {str(e)}"
51
 
52
+ def transcribe_audio(audio_file):
53
+ """Simple transcription function"""
 
 
 
 
 
 
54
  global model, processor, device
55
 
56
  if model is None or processor is None:
57
+ return "❌ Please load the model first by clicking 'Load Model' button."
58
+
59
+ if audio_file is None:
60
+ return "❌ Please upload an audio file."
61
 
62
  try:
63
  # Load and preprocess audio
 
 
 
 
64
  wav, sr = torchaudio.load(audio_file)
65
 
66
+ # Convert to mono if stereo
67
  if wav.shape[0] > 1:
68
+ wav = wav.mean(dim=0, keepdim=True)
69
+
70
+ # Resample to 16kHz if needed
71
  if sr != 16000:
72
  resampler = torchaudio.transforms.Resample(sr, 16000)
73
  wav = resampler(wav)
74
 
75
+ # Limit audio length for free tier (30 seconds max)
76
+ max_length = 16000 * 30 # 30 seconds at 16kHz
77
+ if wav.shape[1] > max_length:
78
+ wav = wav[:, :max_length]
79
+ print("Audio truncated to 30 seconds for processing")
 
 
 
80
 
81
+ # Create simple chat template
82
  chat = [
83
  {
84
  "role": "system",
85
+ "content": "You are Granite, developed by IBM. You are a helpful AI assistant.",
86
  },
87
  {
88
  "role": "user",
89
+ "content": "<|audio|>Please transcribe this audio.",
90
  }
91
  ]
92
 
 
100
  model_inputs = processor(
101
  text,
102
  wav,
 
103
  return_tensors="pt",
104
+ sampling_rate=16000
105
  ).to(device)
106
 
107
+ # Generate with conservative settings
108
  with torch.no_grad():
109
+ outputs = model.generate(
110
  **model_inputs,
111
+ max_new_tokens=100,
112
+ num_beams=2, # Reduced for speed
113
  do_sample=False,
 
 
 
 
114
  temperature=1.0,
 
 
115
  pad_token_id=tokenizer.pad_token_id,
116
  )
117
 
118
  # Decode output
119
  num_input_tokens = model_inputs["input_ids"].shape[-1]
120
+ new_tokens = outputs[0, num_input_tokens:].unsqueeze(0)
121
+ transcription = tokenizer.batch_decode(
122
+ new_tokens, skip_special_tokens=True
123
  )[0]
124
 
125
+ return f"🎤 Transcription:\n\n{transcription}"
126
 
127
  except Exception as e:
128
+ return f"❌ Error during transcription: {str(e)}"
129
 
130
+ # Create Gradio interface
131
+ def create_demo():
132
+ with gr.Blocks(title="Granite Speech Demo", theme=gr.themes.Soft()) as demo:
 
133
  gr.Markdown("""
134
  # 🎀 IBM Granite Speech 3.3-2B Demo
135
 
136
+ Upload an audio file to transcribe speech to text.
137
 
138
+ **Supported**: English, French, German, Spanish, Portuguese
 
 
 
 
 
139
  """)
140
 
141
  with gr.Row():
142
  with gr.Column():
143
+ # Model loading
144
+ load_btn = gr.Button("🔄 Load Model", variant="primary", size="lg")
145
+ status = gr.Textbox(label="Status", interactive=False)
 
146
 
147
+ # Audio input
148
+ audio = gr.Audio(
 
149
  label="Upload Audio File",
150
  type="filepath",
151
  format="wav"
152
  )
153
 
154
+ transcribe_btn = gr.Button("🎯 Transcribe", variant="secondary")
 
 
 
 
 
 
 
 
 
155
 
156
  with gr.Column():
157
+ output = gr.Textbox(
158
+ label="Transcription Result",
159
+ lines=8,
160
+ interactive=False
 
 
 
161
  )
162
 
 
163
  gr.Markdown("""
164
+ ### 💡 Tips:
165
+ - Keep audio files under 30 seconds for free tier
166
+ - Clear speech works best
167
+ - WAV format recommended
 
168
  """)
169
 
170
  # Event handlers
171
+ load_btn.click(load_model, outputs=status)
172
+ transcribe_btn.click(transcribe_audio, inputs=audio, outputs=output)
 
 
 
 
 
 
 
 
173
 
174
  return demo
175
 
 
176
  if __name__ == "__main__":
177
+ demo = create_demo()
178
  demo.launch()