itsbaivab committed on
Commit
dc6c835
·
verified ·
1 Parent(s): 8b44721

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +147 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import librosa
3
+ import librosa.display
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from torchvision.models import vit_b_16
9
+ import torch.nn as nn
10
+ from PIL import Image
11
+ import pickle
12
+ import os
13
+
14
# Browser-tab title, icon and wide layout for the app.
_PAGE_SETTINGS = {
    "page_title": "Baby Cry Analyzer",
    "page_icon": "👶",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)

# Extra spacing for the main container and for alert boxes, injected as raw
# HTML (hence unsafe_allow_html=True).
_CUSTOM_CSS = """
    <style>
    .main {
        padding: 2rem;
    }
    .stAlert {
        margin-top: 1rem;
    }
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
32
+
33
@st.cache_resource
def load_model():
    """Build the ViT-B/16 cry classifier on the CPU and load its weights.

    Reads a pickled state dict from ``baby_cry_model.pkl`` (relative to the
    working directory), creates a ``vit_b_16`` whose classification head is
    replaced by a 3-way linear layer, loads the state dict into it and
    switches the model to eval mode. The result is cached by
    ``st.cache_resource``.

    Returns:
        tuple: ``(model, device)`` where ``device`` is ``torch.device("cpu")``.

    Raises:
        Exception: Any failure while reading the file or loading the weights
            is reported through ``st.error`` and then re-raised unchanged.
    """
    try:
        # Force CPU device
        device = torch.device("cpu")

        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only ship a model file you produced yourself; never point this at
        # user-supplied data.
        with open("baby_cry_model.pkl", "rb") as f:
            raw_state_dict = pickle.load(f)

        # NOTE: the .cpu() calls only run after unpickling has already
        # succeeded, so they cannot rescue a file that fails to unpickle on a
        # host without CUDA; such a file has to be re-saved from CPU tensors.
        model_state_dict = {
            key: value.cpu() if isinstance(value, torch.Tensor) else value
            for key, value in raw_state_dict.items()
        }

        # weights=None: load_state_dict below (strict by default) overwrites
        # every parameter, so fetching the ImageNet checkpoint would only cost
        # startup time and bandwidth. It also replaces the deprecated
        # `pretrained=` argument.
        model = vit_b_16(weights=None)
        num_classes = 3  # Adjust based on your actual number of classes
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)

        model = model.to(device)
        model.load_state_dict(model_state_dict)
        model.eval()

        return model, device
    except Exception as e:
        st.error(f"""
            Error loading model. Make sure the model file exists and is accessible.
            If this error persists, the model might need to be re-saved for CPU compatibility.
            Technical details: {str(e)}
            """)
        # Bare raise keeps the original traceback intact.
        raise
63
+
64
def create_spectrogram(audio_file):
    """Render a mel spectrogram of ``audio_file`` to a PNG and return its path.

    The audio is loaded at 22050 Hz, converted to a 128-band mel spectrogram
    in dB (relative to its maximum) and drawn without axes on a 5x5 inch
    figure.

    Args:
        audio_file: Anything ``librosa.load`` accepts (a path or a file-like
            object such as a Streamlit upload).

    Returns:
        str: Path of the PNG that was written. The caller is responsible for
        deleting it.
    """
    import tempfile

    y, sr = librosa.load(audio_file, sr=22050)
    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # A unique file per call: a fixed name would be shared by every
    # concurrent session, letting one session overwrite or delete the image
    # another session is still reading.
    with tempfile.NamedTemporaryFile(
        prefix="temp_spectrogram_", suffix=".png", delete=False
    ) as tmp:
        temp_path = tmp.name

    # Explicit figure/axes instead of pyplot's implicit "current figure",
    # which is process-global state.
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        librosa.display.specshow(
            mel_spec_db, sr=sr, x_axis="time", y_axis="mel", ax=ax
        )
        ax.axis("off")
        fig.savefig(temp_path, bbox_inches="tight", pad_inches=0)
    except BaseException:
        # Do not leave an empty placeholder file behind when rendering fails.
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise
    finally:
        # Always release the figure, even on failure.
        plt.close(fig)

    return temp_path
80
+
81
def classify_audio(model, device, spectrogram_path):
    """Run the classifier on a spectrogram image.

    The image is opened as RGB, resized to 224x224, converted to a tensor,
    normalized with mean 0.5 / std 0.5 and passed through ``model`` as a
    batch of one with gradients disabled.

    Args:
        model: The network to evaluate.
        device: Device the input batch is moved to before inference.
        spectrogram_path: Path of the spectrogram image to classify.

    Returns:
        tuple: ``(predicted_class, probabilities)`` where ``predicted_class``
        is the argmax index as a Python int and ``probabilities`` is the
        softmax over the first (only) sample's outputs.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    rgb_image = Image.open(spectrogram_path).convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0).to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits[0], dim=0)
        predicted_class = torch.argmax(logits, dim=1).item()

    return predicted_class, probabilities
98
+
99
def main():
    """Streamlit entry point: upload a WAV file, show its spectrogram and
    the predicted cry type with per-class probabilities.

    Returns early if the model cannot be loaded or no file has been uploaded
    yet. The temporary spectrogram image is removed even when display or
    classification fails.
    """
    st.title("👶 Baby Cry Analyzer")
    st.write("Upload a WAV file to analyze the type of baby cry")

    # Load model
    try:
        model, device = load_model()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return
    else:
        st.success("Model loaded successfully!")

    # File upload
    audio_file = st.file_uploader("Choose a WAV file", type=['wav'])
    if audio_file is None:
        return

    st.audio(audio_file)

    with st.spinner("Analyzing audio..."):
        # Create and display spectrogram
        spec_path = create_spectrogram(audio_file)
        try:
            st.image(spec_path, caption="Generated Spectrogram", width=300)

            # Classify
            predicted_class, probabilities = classify_audio(model, device, spec_path)

            # Display results
            classes = ['Belly Pain', 'Hungry', 'Tired']  # Adjust based on your classes
            st.subheader("Classification Results:")

            # Display prediction with confidence
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Predicted Cry Type", classes[predicted_class])
            with col2:
                confidence = float(probabilities[predicted_class]) * 100
                st.metric("Confidence", f"{confidence:.2f}%")

            # Show all probabilities
            st.subheader("Probability Distribution:")
            for cls, prob in zip(classes, probabilities):
                st.write(f"{cls}: {float(prob)*100:.2f}%")
        finally:
            # Cleanup runs on failure too, so a crash during display or
            # classification does not leave the image on disk.
            if os.path.exists(spec_path):
                os.remove(spec_path)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ librosa
3
+ matplotlib
4
+ numpy
5
+ torch
6
+ torchvision
7
+ pillow