SeenSiravit commited on
Commit
d7bb333
·
verified ·
1 Parent(s): 5113b14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -2
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import numpy as np
2
-
3
  from PIL import Image, ImageFilter
4
 
5
  import torch
@@ -13,6 +12,9 @@ from fastapi.responses import JSONResponse
13
  import uvicorn
14
  from fastapi.middleware.cors import CORSMiddleware
15
 
 
 
 
16
  import datetime
17
  import pytz
18
 
@@ -126,7 +128,7 @@ async def predict(file: UploadFile = File(...)):
126
  # image.save('input.jpg')
127
 
128
  image_rz = image.resize((256,256))
129
- # image_rz.save('input_resize.jpg')
130
 
131
  print(f"image after fill bg : {image.size, type(image)}\n")
132
 
@@ -136,10 +138,114 @@ async def predict(file: UploadFile = File(...)):
136
  parkinson_predict = predict_parkinson(image)
137
  print(f"parkinson predict : {parkinson_predict}\n")
138
 
 
139
  print(f"end time : {datetime.datetime.now(pytz.timezone('Asia/Bangkok'))}")
140
 
141
  return JSONResponse([spiral_predict, parkinson_predict])
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  if __name__ == "__main__":
144
  import uvicorn
145
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import numpy as np
 
2
  from PIL import Image, ImageFilter
3
 
4
  import torch
 
12
  import uvicorn
13
  from fastapi.middleware.cors import CORSMiddleware
14
 
15
+ from numpy.fft import rfft, irfft, rfftfreq
16
+ import pandas as pd
17
+
18
  import datetime
19
  import pytz
20
 
 
128
  # image.save('input.jpg')
129
 
130
  image_rz = image.resize((256,256))
131
+ # image_rz.save('input1.jpg')
132
 
133
  print(f"image after fill bg : {image.size, type(image)}\n")
134
 
 
138
  parkinson_predict = predict_parkinson(image)
139
  print(f"parkinson predict : {parkinson_predict}\n")
140
 
141
+ curr_time = datetime.datetime.now()
142
  print(f"end time : {datetime.datetime.now(pytz.timezone('Asia/Bangkok'))}")
143
 
144
  return JSONResponse([spiral_predict, parkinson_predict])
145
 
146
+ SAMPLING_RATE = 100
147
+ DURATION = 10
148
+ time_axis = np.linspace(1/SAMPLING_RATE, DURATION, SAMPLING_RATE*DURATION)
149
+ print(f"len time axis : {len(time_axis)}")
150
+
151
+ def spooled_tempfile_to_string(spooled_tempfile) -> str:
152
+ spooled_tempfile.seek(0)
153
+
154
+ raw_content = spooled_tempfile.read()
155
+
156
+ if isinstance(raw_content, bytes):
157
+ content = raw_content.decode('utf-8')
158
+
159
+ return content
160
+
161
+ def encode(raw_content) :
162
+ content = raw_content[raw_content.find('[') + 1 : raw_content.rfind(']')]
163
+
164
+ data_list = []
165
+
166
+ i = 0
167
+
168
+ while i < len(content) :
169
+ idx_open = content.find('[', i)
170
+ idx_close = content.find(']', idx_open)
171
+
172
+ if idx_open==-1 or idx_close==-1 : break
173
+
174
+ txt = content[idx_open+1 : idx_close]
175
+ row = [float(val) for val in txt.split(',')]
176
+
177
+ data_list.append(row)
178
+
179
+ i = idx_close + 1
180
+
181
+ df = pd.DataFrame(data_list, columns=['roll','pitch','yaw'], index=time_axis)
182
+
183
+ return df
184
+
185
+ class Fourier :
186
+ def __init__(self, signal, sampling_rate, duration) :
187
+ self.signal = np.array(signal)
188
+ self.sampling_rate = sampling_rate
189
+ self.duration = duration
190
+
191
+ self.time = np.linspace(1/sampling_rate, duration, len(signal))
192
+
193
+ self.freq = self.get_frequencies()
194
+ self.raw_amplitudes = self.get_amplitudes()
195
+ self.amplitudes = self.norm_amplitudes()
196
+
197
+ def get_frequencies(self) :
198
+ return rfftfreq(len(self.signal), 1/self.sampling_rate)
199
+
200
+ def get_amplitudes(self) :
201
+ return rfft(self.signal)
202
+
203
+ def norm_amplitudes(self) :
204
+ return 2*np.abs(self.raw_amplitudes) / len(self.signal)
205
+
206
+ def predict_tremor(signal, sampling_rate, duration) :
207
+ fourier = Fourier(signal, sampling_rate, duration)
208
+ amp = fourier.amplitudes
209
+
210
+ start_freq = 0.5
211
+ start_idx = np.argwhere(fourier.freq >= start_freq)[0][0]
212
+
213
+ idx_3hz = np.argwhere(fourier.freq >= 3)[0][0]
214
+ idx_6hz = np.argwhere(fourier.freq >= 6)[0][0]
215
+
216
+ max_amp_parkinson_range = np.max(amp[idx_3hz:idx_6hz+1])
217
+ if max_amp_parkinson_range > 1.5 : return 2
218
+
219
+ max_amp_tremor_range = np.max(amp[start_idx:])
220
+ if max_amp_tremor_range > 1 : return 1
221
+
222
+ return 0
223
+
224
+ @app.post("/predict_shake/")
225
+ async def predict_shake(file: UploadFile = File(...)):
226
+ raw_content = spooled_tempfile_to_string(file.file)
227
+
228
+ print(f"len raw_content : {len(raw_content)}")
229
+ print(f"raw_content: {raw_content[:10]}")
230
+
231
+ df = encode(raw_content)
232
+ print(f"df shape : {df.shape}")
233
+
234
+ level_roll = predict_tremor(df['roll'], SAMPLING_RATE, DURATION)
235
+ level_pitch = predict_tremor(df['pitch'], SAMPLING_RATE, DURATION)
236
+ level_yaw = predict_tremor(df['yaw'], SAMPLING_RATE, DURATION)
237
+
238
+ print(f"roll: {level_roll}, pitch: {level_pitch}, yaw: {level_yaw}")
239
+
240
+ highest_level = np.max([level_roll, level_pitch, level_yaw])
241
+
242
+ level_names = ['low', 'mid', 'high']
243
+
244
+ curr_time = datetime.datetime.now()
245
+ print(f"curr time : {curr_time}")
246
+
247
+ return level_names[highest_level]
248
+
249
  if __name__ == "__main__":
250
  import uvicorn
251
  uvicorn.run(app, host="0.0.0.0", port=8000)