GranularFireplace commited on
Commit
8d18eb0
·
verified ·
1 Parent(s): 42a221c

return other results

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -251,11 +251,18 @@ async def download_file(file_name: str):
251
 
252
  return FileResponse(file_path, filename=sanitized_name)
253
 
 
 
 
 
254
  def predict_malware(img_array: np.ndarray) -> str:
255
  """Make prediction using the preloaded model"""
256
  try:
257
  prediction = app.state.model.predict(img_array)
258
- return MAL_CLASSES[np.argmax(prediction)]
 
 
 
259
  except Exception as e:
260
  logger.error(f"Prediction error: {str(e)}")
261
  raise HTTPException(
@@ -291,7 +298,7 @@ async def analyse(file_name: str):
291
  try:
292
  img_array = await process_image(file_path)
293
  result = predict_malware(img_array)
294
- return {"result": result}
295
  except HTTPException as he:
296
  raise he
297
  except Exception as e:
@@ -361,7 +368,7 @@ async def analyse_bin(file_name: str):
361
  convert_binary_to_image(file_path, temp_image, width)
362
  img_array = await process_image(temp_image)
363
  result = predict_malware(img_array)
364
- return {"result": result}
365
 
366
  except HTTPException as he:
367
  raise he
 
251
 
252
  return FileResponse(file_path, filename=sanitized_name)
253
 
254
+ def softmax(x):
255
+ exp_x = np.exp(x - np.max(x)) # Subtract max for numerical stability
256
+ return exp_x / np.sum(exp_x)
257
+
258
  def predict_malware(img_array: np.ndarray) -> str:
259
  """Make prediction using the preloaded model"""
260
  try:
261
  prediction = app.state.model.predict(img_array)
262
+ return {
263
+ "result": MAL_CLASSES[np.argmax(prediction)],
264
+ "all_results": dict(zip(keys, softmax(prediction)))
265
+ }
266
  except Exception as e:
267
  logger.error(f"Prediction error: {str(e)}")
268
  raise HTTPException(
 
298
  try:
299
  img_array = await process_image(file_path)
300
  result = predict_malware(img_array)
301
+ return result
302
  except HTTPException as he:
303
  raise he
304
  except Exception as e:
 
368
  convert_binary_to_image(file_path, temp_image, width)
369
  img_array = await process_image(temp_image)
370
  result = predict_malware(img_array)
371
+ return result
372
 
373
  except HTTPException as he:
374
  raise he