ProfRom committed on
Commit
523547d
·
verified ·
1 Parent(s): 921baab

Saar - Sanity Check

Browse files
Files changed (1) hide show
  1. app.py +75 -187
app.py CHANGED
@@ -1,210 +1,98 @@
1
- import gradio as gr
2
  import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import torchaudio
6
- import numpy as np
7
- from datasets import load_dataset
8
-
9
-
10
- # ---------------------------
11
- # Constants
12
- # ---------------------------
13
# Audio front-end parameters used by the resampler and the Mel transform.
TARGET_SR = 44100   # sample rate every clip is resampled to
N_FFT = 1024        # FFT size; also the minimum clip length enforced before the STFT
HOP_LENGTH = 512    # hop between successive STFT frames
N_MELS = 64         # number of Mel bands

# The dataset is loaded here only to read the class-label names of its
# "train" split; the number of classes sizes the classification head.
dataset = load_dataset("ccmusic-database/pianos", name="8_class")
label_names = dataset["train"].features["label"].names
num_classes = len(label_names)
25
-
26
-
27
- # ---------------------------
28
- # Define the Same CNN Model as in Training
29
- # ---------------------------
30
class PianoCNNMultiTask(nn.Module):
    """CNN with a shared trunk feeding a classification and a regression head.

    Four Conv-BatchNorm-ReLU stages (3 -> 16 -> 32 -> 64 -> 128 channels) are
    followed by pooling; the last stage uses a 4x4 adaptive average pool so
    the flattened feature vector always has ``128 * 4 * 4`` entries.

    Args:
        num_classes: Number of outputs of the classification head.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Layers are appended in a fixed order so the ``features.<index>``
        # parameter names stay the same as a hand-written Sequential.
        widths = (3, 16, 32, 64, 128)
        final_stage = len(widths) - 2
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
            if stage < final_stage:
                layers.append(nn.MaxPool2d(2))  # halves height and width
            else:
                layers.append(nn.AdaptiveAvgPool2d((4, 4)))  # 4x4 feature map

        self.features = nn.Sequential(*layers)
        self.flatten = nn.Flatten()
        self.fc_shared = nn.Linear(128 * 4 * 4, 256)
        self.dropout = nn.Dropout(0.3)

        # Two heads on top of the shared 256-d representation.
        self.fc_class = nn.Linear(256, num_classes)  # class logits
        self.fc_reg = nn.Linear(256, 1)              # scalar quality score

    def forward(self, x):
        """Return ``(class_logits, quality_pred)`` for a batch of images.

        ``class_logits`` has shape ``[batch, num_classes]``; ``quality_pred``
        has its trailing singleton dimension squeezed to shape ``[batch]``.
        """
        embedding = self.flatten(self.features(x))
        hidden = self.dropout(F.relu(self.fc_shared(embedding)))
        return self.fc_class(hidden), self.fc_reg(hidden).squeeze(1)
76
-
77
-
78
- # ---------------------------
79
- # Initialize and Load Trained Model (CPU)
80
- # ---------------------------
81
# Rebuild the network and restore its trained weights on the CPU.
# NOTE(review): torch.load unpickles the checkpoint file, so it must come from
# a trusted source; consider weights_only=True if the installed torch allows it.
model = PianoCNNMultiTask(num_classes=num_classes)
state_dict = torch.load("piano_cnn_multitask.pt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model.eval()  # inference mode (affects the Dropout and BatchNorm layers)

# ---------------------------
# Audio Preprocessing
# ---------------------------
# center=False: no implicit padding around frames; short clips are padded
# explicitly by the preprocessing function instead.
mel_transform = torchaudio.transforms.MelSpectrogram(
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=N_FFT,
    sample_rate=TARGET_SR,
    center=False,
)
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
def preprocess_audio_to_mel_image(audio):
    """Convert a Gradio audio tuple into a normalized 3x128x128 Mel image.

    Args:
        audio: ``(sample_rate, data)`` as delivered by
            ``gr.Audio(type="numpy")``. ``data`` is 1-D ``(samples,)`` for
            mono input or 2-D ``(samples, channels)`` for multi-channel input;
            a 2-D array that is already ``(channels, samples)`` is accepted too.

    Returns:
        A float32 tensor of shape ``[3, 128, 128]`` scaled to ``[0, 1]``: the
        dB Mel-spectrogram resized to 128x128 and repeated over 3 channels.
    """
    sr, data = audio

    waveform = torch.tensor(data, dtype=torch.float32)

    if waveform.ndim == 1:
        # Mono clip: add the channel axis -> (1, samples).
        waveform = waveform.unsqueeze(0)
    elif waveform.ndim == 2 and waveform.shape[0] > waveform.shape[1]:
        # The longer axis is time, so a tall array is (samples, channels);
        # move channels first. Arrays that are already (channels, samples)
        # are left untouched.
        waveform = waveform.transpose(0, 1)

    # Down-mix multi-channel audio by averaging across channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample to TARGET_SR if needed.
    if sr != TARGET_SR:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
        waveform = resampler(waveform)

    # The STFT needs at least one full window; zero-pad shorter clips on the right.
    missing = N_FFT - waveform.shape[-1]
    if missing > 0:
        waveform = F.pad(waveform, (0, missing))

    # Mel-spectrogram in dB: [1, n_mels, time].
    mel = mel_transform(waveform)
    mel_db = torchaudio.transforms.AmplitudeToDB()(mel)

    # Min-max normalize to 0-1; the epsilon guards against a constant spectrogram.
    mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)

    # Resize to 128x128 and replicate to 3 channels for the CNN input.
    mel_db = mel_db.unsqueeze(0)  # [1, 1, H, W]
    mel_resized = F.interpolate(mel_db, size=(128, 128), mode="bilinear", align_corners=False)
    mel_rgb = mel_resized.repeat(1, 3, 1, 1)  # [1, 3, 128, 128]

    return mel_rgb.squeeze(0)  # [3, 128, 128]
155
 
 
 
 
 
 
156
 
157
- # ---------------------------
158
- # Main Inference Function
159
- # ---------------------------
160
def analyze_piano(audio):
    """Run the piano model on a Gradio audio input and describe the result.

    Args:
        audio: ``(sample_rate, data)`` tuple from ``gr.Audio(type="numpy")``,
            or ``None`` when nothing was uploaded or recorded.

    Returns:
        A text report with the predicted piano type and the quality score
        rounded to two decimals, a prompt when ``audio`` is ``None``, or an
        error message if processing raised.
    """
    if audio is None:
        return "Please upload or record a piano audio clip (around 1–3 seconds)."

    try:
        # [3,128,128] image -> batch of one: [1,3,128,128]
        batch = preprocess_audio_to_mel_image(audio).unsqueeze(0)

        with torch.no_grad():
            class_logits, quality_pred = model(batch)
            predicted_index = torch.argmax(class_logits, dim=1).item()
            raw_score = float(quality_pred.item())

        predicted_type = label_names[predicted_index]
        score = round(raw_score, 2)
        return (
            f"Piano Type Prediction: {predicted_type}\n"
            f"Estimated Sound Quality Score: {score} / 10"
        )

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"An error occurred while processing the audio: {e}"
 
 
190
 
191
-
192
- # ---------------------------
193
- # Gradio Interface
194
- # ---------------------------
195
# One audio input (file upload or microphone recording) mapped to one text output.
demo = gr.Interface(
    title="AI Piano Sound Analyzer 🎹",
    description="Upload a short piano recording to get a predicted piano type and estimated sound-quality score from the trained CNN model.",
    fn=analyze_piano,
    inputs=gr.Audio(
        label="Upload Piano Audio or Record with Microphone",
        type="numpy",  # handler receives (sample_rate, data)
        sources=["upload", "microphone"],
    ),
    outputs=gr.Textbox(label="AI Analysis Output"),
)

# Start the web UI only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
210
-
 
 
1
  import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from transformers import (
5
+ BlipProcessor,
6
+ BlipForConditionalGeneration,
7
+ pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  )
9
 
10
# Use the GPU in half precision when CUDA is available; otherwise CPU in float32.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

# BLIP captioning: the processor prepares inputs and decodes generated tokens,
# the model generates the caption token ids.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base",
    torch_dtype=torch_dtype,
)
caption_model = caption_model.to(device)

# ViT image classifier; the pipeline takes a device index (0 = first GPU, -1 = CPU).
classifier = pipeline(
    task="image-classification",
    model="google/vit-base-patch16-224",
    device=0 if device == "cuda" else -1,
)

print("Models loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
def generate_caption(image):
    """Return the caption text the BLIP model generates for ``image``.

    The image is preprocessed with the module-level ``processor``, moved to
    the configured device and dtype, and up to 30 new tokens are generated.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device, torch_dtype)
    token_ids = caption_model.generate(**model_inputs, max_new_tokens=30)
    # Only the first (and only) sequence of the batch is decoded.
    return processor.decode(token_ids[0], skip_special_tokens=True)
35
 
36
def build_summary(caption: str, label: str) -> str:
    """Combine a caption and a class label into a three-sentence summary.

    Args:
        caption: Caption text. ``None``, empty or whitespace-only values are
            replaced by ``"No caption available"``.
        label: Class label. ``None``, empty or whitespace-only values are
            replaced by ``"unknown object"``.

    Returns:
        The summary string.
    """
    # Strip before applying the fallback so whitespace-only input does not
    # leave an empty slot in the sentence.
    caption = (caption or "").strip() or "No caption available"
    label = (label or "").strip() or "unknown object"

    return (
        f"The captioning model describes the image as: {caption}. "
        f"The image classification model identifies the main subject as: {label}. "
        "Taken together, the image appears to focus on this subject or scene."
    )
45
 
46
def analyze_image(image):
    """Caption and classify an uploaded image and summarize both results.

    Args:
        image: Object with a ``convert("RGB")`` method (the interface supplies
            a PIL image), or ``None`` when nothing was uploaded.

    Returns:
        A ``(caption, classification_text, summary)`` tuple of strings. When
        ``image`` is ``None`` the three entries are prompts; if processing
        raises, all three entries carry the same error message.
    """
    if image is None:
        return (
            "Please upload an image.",
            "No classification available.",
            "Please upload an image first.",
        )

    try:
        rgb_image = image.convert("RGB")

        # Captioning
        caption = generate_caption(rgb_image)
        print("CAPTION RESULT:", caption)

        # Classification
        predictions = classifier(rgb_image)
        print("CLASSIFICATION RESULT:", predictions)

        has_prediction = isinstance(predictions, list) and len(predictions) > 0
        if not has_prediction:
            top_label = "Unknown"
            classification_text = "No classification generated."
        else:
            # Only the first entry of the result list is reported.
            best = predictions[0]
            top_label = best.get("label", "Unknown")
            top_score = best.get("score", 0.0)
            classification_text = f"{top_label} (confidence: {top_score:.4f})"

        return caption, classification_text, build_summary(caption, top_label)

    except Exception as e:
        # UI boundary: log the failure and show it in every output box.
        print("ERROR:", str(e))
        error_text = f"Error: {str(e)}"
        return error_text, error_text, error_text
81
 
 
 
 
 
82
# One image input mapped to three text outputs, in the order the handler
# returns them: caption, top classification, combined summary.
demo = gr.Interface(
    title="Image Captioning, Classification, and Summary App",
    description=(
        "Upload an image to generate an automatic caption, predict the main image class, "
        "and produce a short combined summary."
    ),
    fn=analyze_image,
    inputs=gr.Image(label="Upload an Image", type="pil"),
    outputs=[
        gr.Textbox(label="Generated Caption"),
        gr.Textbox(label="Top Classification"),
        gr.Textbox(lines=4, label="Combined Summary"),
    ],
)

# Start the web UI only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()