ProfRom committed on
Commit 555f48a · verified · 1 Parent(s): 647e467

House - Unit 8 Assignment

Files changed (1): app.py (+199 -43)
app.py CHANGED
@@ -1,54 +1,210 @@
-from transformers import pipeline
-from PIL import Image
 import gradio as gr

-# VQA pipeline
-vqa_pipeline = pipeline(
-    "visual-question-answering",
-    model="dandelin/vilt-b32-finetuned-vqa"
-)

-# English -> Korean translator (reliable alternative)
-translator = pipeline(
-    "translation",
-    model="facebook/m2m100_418M"
 )

-def vqa_answer(image: Image.Image, question: str):
-    if image is None:
-        return "Please upload an image."
-    if not question or not question.strip():
-        return "Please enter a question about the image."
-
-    # VQA
-    result = vqa_pipeline(image=image, question=question)
-    top = result[0]
-    answer = top["answer"]
-    score = top.get("score", None)
-    score_str = f"{score:.3f}" if isinstance(score, (float, int)) else "N/A"
-
-    # Translate EN → KO
-    translated = translator(
-        answer,
-        src_lang="en",
-        tgt_lang="ko"
-    )[0]["translation_text"]
-
-    return (
-        f"Answer (EN): {answer} (score: {score_str})\n\n"
-        f"번역 (KO): {translated}"
-    )

 demo = gr.Interface(
-    fn=vqa_answer,
-    inputs=[
-        gr.Image(type="pil", label="Upload an image"),
-        gr.Textbox(lines=2, label="Question about the image")
-    ],
-    outputs=gr.Textbox(label="VQA Result (English + Korean)"),
-    title="Visual Question Answering + Korean Translation",
-    description="Upload an image, ask a question, and see the answer in English and Korean."
 )

 if __name__ == "__main__":
     demo.launch()

 import gradio as gr
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+import numpy as np
+from datasets import load_dataset
+
+# ---------------------------
+# Constants
+# ---------------------------
+TARGET_SR = 44100
+N_FFT = 1024
+HOP_LENGTH = 512
+N_MELS = 64
+
+
+# ---------------------------
+# Load Dataset Metadata for Labels
+# ---------------------------
+dataset = load_dataset("ccmusic-database/pianos", name="8_class")
+label_names = dataset["train"].features["label"].names
+num_classes = len(label_names)
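
A note on startup cost: `load_dataset` pulls the whole `8_class` split here just to read the label names. A minimal sketch of a cheaper fallback, assuming the pinned order is verified once against the real `label_names` (the names below are placeholders, not taken from this commit):

```python
# Hypothetical fallback: print `label_names` from the original code once,
# then pin the verified order to skip the dataset download at app startup.
PINNED_LABELS = ["piano_type_0", "piano_type_1", "piano_type_2", "piano_type_3",
                 "piano_type_4", "piano_type_5", "piano_type_6", "piano_type_7"]
try:
    from datasets import load_dataset
    label_names = load_dataset(
        "ccmusic-database/pianos", name="8_class", split="train"
    ).features["label"].names
except Exception:
    label_names = PINNED_LABELS  # placeholder order, verify before relying on it
num_classes = len(label_names)
```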
+
+
+# ---------------------------
+# Define the Same CNN Model as in Training
+# ---------------------------
+class PianoCNNMultiTask(nn.Module):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 16, kernel_size=3, padding=1),
+            nn.BatchNorm2d(16),
+            nn.ReLU(),
+            nn.MaxPool2d(2),  # 128 -> 64
+
+            nn.Conv2d(16, 32, kernel_size=3, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.MaxPool2d(2),  # 64 -> 32
+
+            nn.Conv2d(32, 64, kernel_size=3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.MaxPool2d(2),  # 32 -> 16
+
+            nn.Conv2d(64, 128, kernel_size=3, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.AdaptiveAvgPool2d((4, 4))  # 4x4 feature map
+        )
+        self.flatten = nn.Flatten()
+        self.fc_shared = nn.Linear(128 * 4 * 4, 256)
+        self.dropout = nn.Dropout(0.3)
+
+        # Classification head
+        self.fc_class = nn.Linear(256, num_classes)
+        # Regression head (quality score)
+        self.fc_reg = nn.Linear(256, 1)
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.flatten(x)
+        x = F.relu(self.fc_shared(x))
+        x = self.dropout(x)
+        class_logits = self.fc_class(x)
+        quality_pred = self.fc_reg(x).squeeze(1)
+        return class_logits, quality_pred
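
A quick sanity check of the architecture (a sketch, not part of the commit): the feature stack maps any 3x128x128 input to a 128x4x4 map, so the two heads return `[B, num_classes]` logits and a `[B]` quality vector.

```python
# Smoke test: confirm output shapes for a dummy batch of two spectrogram images.
_m = PianoCNNMultiTask(num_classes=8)
_x = torch.randn(2, 3, 128, 128)
_logits, _q = _m(_x)
print(_logits.shape, _q.shape)  # torch.Size([2, 8]) torch.Size([2])
```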
+
+
+# ---------------------------
+# Initialize and Load Trained Model (CPU)
+# ---------------------------
+model = PianoCNNMultiTask(num_classes=num_classes)
+state_dict = torch.load("piano_cnn_multitask.pt", map_location=torch.device("cpu"))
+model.load_state_dict(state_dict)
+model.eval()  # inference mode
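
On recent PyTorch (1.13 and later), `weights_only=True` restricts unpickling to plain tensors, which is safer for a checkpoint like this. A sketch, assuming the file really is a bare `state_dict`:

```python
# Sketch: prefer the restricted loader where available, fall back otherwise.
try:
    state_dict = torch.load("piano_cnn_multitask.pt",
                            map_location="cpu", weights_only=True)
except TypeError:  # older PyTorch without the weights_only kwarg
    state_dict = torch.load("piano_cnn_multitask.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```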
+
+
+# ---------------------------
+# Audio Preprocessing
+# ---------------------------
+mel_transform = torchaudio.transforms.MelSpectrogram(
+    sample_rate=TARGET_SR,
+    n_fft=N_FFT,
+    hop_length=HOP_LENGTH,
+    n_mels=N_MELS,
+    center=False  # we will handle padding manually
 )
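
With `center=False`, the transform yields `1 + (n_samples - N_FFT) // HOP_LENGTH` frames, so a one-second 44.1 kHz clip gives a `[1, 64, 85]` spectrogram before the later resize to 128x128. A quick check (sketch):

```python
# One second of noise at the target rate -> 1 + (44100 - 1024) // 512 = 85 frames.
_spec = mel_transform(torch.randn(1, TARGET_SR))
print(_spec.shape)  # torch.Size([1, 64, 85])
```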
 
+def preprocess_audio_to_mel_image(audio):
+    """
+    `audio` from gr.Audio(type="numpy") is a (sample_rate, data) tuple.
+    Returns a 3x128x128 tensor ready for the CNN.
+    """
+    sr, data = audio
+
+    # Convert to tensor
+    waveform = torch.tensor(data, dtype=torch.float32)
+
+    # If shape is (samples,), make it (1, samples)
+    if waveform.ndim == 1:
+        waveform = waveform.unsqueeze(0)
+
+    # If shape is (samples, channels), transpose to (channels, samples).
+    # Samples always outnumber channels, so rows > columns marks time-major data.
+    if waveform.ndim == 2 and waveform.shape[0] > waveform.shape[1]:
+        waveform = waveform.transpose(0, 1)
+
+    # Convert to mono if stereo
+    if waveform.shape[0] > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+
+    # Resample to TARGET_SR if needed
+    if sr != TARGET_SR:
+        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
+        waveform = resampler(waveform)
+
+    # Ensure minimum length for STFT
+    min_len = N_FFT
+    if waveform.shape[-1] < min_len:
+        pad_amount = min_len - waveform.shape[-1]
+        waveform = F.pad(waveform, (0, pad_amount))
+
+    # Compute Mel-spectrogram and convert to dB
+    mel = mel_transform(waveform)  # [1, n_mels, time]
+    mel_db = torchaudio.transforms.AmplitudeToDB()(mel)
+
+    # Normalize to 0-1
+    mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)
+
+    # Resize to 128x128 and make 3 channels
+    mel_db = mel_db.unsqueeze(0)  # [1, 1, H, W]
+    mel_resized = F.interpolate(mel_db, size=(128, 128), mode="bilinear", align_corners=False)
+    mel_rgb = mel_resized.repeat(1, 3, 1, 1)  # [1, 3, 128, 128]
+
+    return mel_rgb.squeeze(0)  # [3, 128, 128]
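
The function can be exercised without the UI by passing the same `(sample_rate, ndarray)` tuple that `gr.Audio(type="numpy")` produces; a sketch with synthetic input:

```python
# Half a second of noise at 16 kHz: resampled to 44.1 kHz, mel-transformed,
# then resized, ending as a [3, 128, 128] image tensor.
_fake = (16000, np.random.randn(8000).astype(np.float32))
print(preprocess_audio_to_mel_image(_fake).shape)  # torch.Size([3, 128, 128])
```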
+
+
+# ---------------------------
+# Main Inference Function
+# ---------------------------
+def analyze_piano(audio):
+    if audio is None:
+        return "Please upload or record a piano audio clip (around 1–3 seconds)."
+
+    try:
+        # Preprocess input
+        mel_img = preprocess_audio_to_mel_image(audio)  # [3, 128, 128]
+        mel_batch = mel_img.unsqueeze(0)                # [1, 3, 128, 128]
+
+        with torch.no_grad():
+            logits, q_pred = model(mel_batch)
+            class_idx = torch.argmax(logits, dim=1).item()
+            quality_score = float(q_pred.item())
+
+        piano_type = label_names[class_idx]
+        quality_score_rounded = round(quality_score, 2)
+
+        output_text = (
+            f"Piano Type Prediction: {piano_type}\n"
+            f"Estimated Sound Quality Score: {quality_score_rounded} / 10"
+        )
+        return output_text
+
+    except Exception as e:
+        return f"An error occurred while processing the audio: {e}"
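
`analyze_piano` takes the raw Gradio tuple, so it can be smoke-tested end to end from a REPL (sketch; requires `piano_cnn_multitask.pt` to be present, since the model loads at import time):

```python
# A one-second 440 Hz sine as a stand-in for a piano note.
_sr = 44100
_tone = np.sin(2 * np.pi * 440.0 * np.arange(_sr) / _sr).astype(np.float32)
print(analyze_piano((_sr, _tone)))
```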
+
+
+# ---------------------------
+# Gradio Interface
+# ---------------------------
 demo = gr.Interface(
+    fn=analyze_piano,
+    inputs=gr.Audio(
+        sources=["upload", "microphone"],
+        type="numpy",
+        label="Upload Piano Audio or Record with Microphone"
+    ),
+    outputs=gr.Textbox(label="AI Analysis Output"),
+    title="AI Piano Sound Analyzer 🎹",
+    description="Upload a short piano recording to get a predicted piano type and an estimated sound-quality score from the trained CNN model."
 )

+
 if __name__ == "__main__":
     demo.launch()
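
For completeness, `launch()` accepts standard options not used in this commit; a sketch (both kwargs are regular Gradio parameters, and their use here is an assumption about deployment, not part of the commit):

```python
# share=True requests a temporary public URL; server_name="0.0.0.0" binds to
# all interfaces so the app is reachable from other machines on the network.
demo.launch(share=True, server_name="0.0.0.0")
```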