GranularFireplace commited on
Commit
eb27cf6
·
verified ·
1 Parent(s): 8d18eb0

return link to image generated from ai

Browse files
Files changed (1) hide show
  1. app.py +30 -26
app.py CHANGED
@@ -261,7 +261,7 @@ def predict_malware(img_array: np.ndarray) -> str:
261
  prediction = app.state.model.predict(img_array)
262
  return {
263
  "result": MAL_CLASSES[np.argmax(prediction)],
264
- "all_results": dict(zip(keys, softmax(prediction)))
265
  }
266
  except Exception as e:
267
  logger.error(f"Prediction error: {str(e)}")
@@ -283,30 +283,30 @@ async def process_image(file_path: str, target_size: tuple = (64, 64)) -> np.nda
283
  detail="Invalid image file format"
284
  )
285
 
286
- @app.get("/analyse/{file_name}")
287
- async def analyse(file_name: str):
288
- """Analyze image files"""
289
- sanitized_name = Path(file_name).name
290
- file_path = os.path.join(UPLOAD_DIR, sanitized_name)
291
 
292
- if not os.path.exists(file_path):
293
- raise HTTPException(
294
- status_code=status.HTTP_404_NOT_FOUND,
295
- detail="File not found"
296
- )
297
 
298
- try:
299
- img_array = await process_image(file_path)
300
- result = predict_malware(img_array)
301
- return result
302
- except HTTPException as he:
303
- raise he
304
- except Exception as e:
305
- logger.error(f"Analysis error: {str(e)}")
306
- raise HTTPException(
307
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
308
- detail=f"Analysis failed: {str(e)}"
309
- )
310
 
311
  def get_image_width(file_path: str) -> int:
312
  """Determine image width based on file size"""
@@ -349,6 +349,12 @@ def convert_binary_to_image(binary_path: str, output_path: str, width: int):
349
  detail="Invalid binary file format"
350
  )
351
 
 
 
 
 
 
 
352
  @app.get("/analysebin/{file_name}")
353
  async def analyse_bin(file_name: str):
354
  """Analyze binary files by converting to images"""
@@ -368,6 +374,7 @@ async def analyse_bin(file_name: str):
368
  convert_binary_to_image(file_path, temp_image, width)
369
  img_array = await process_image(temp_image)
370
  result = predict_malware(img_array)
 
371
  return result
372
 
373
  except HTTPException as he:
@@ -378,9 +385,6 @@ async def analyse_bin(file_name: str):
378
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
379
  detail=f"Binary analysis failed: {str(e)}"
380
  )
381
- finally:
382
- if os.path.exists(temp_image):
383
- os.remove(temp_image)
384
 
385
  @app.get("/analyse/yara/{file_name}")
386
  async def analyse_yara(file_name: str):
 
261
  prediction = app.state.model.predict(img_array)
262
  return {
263
  "result": MAL_CLASSES[np.argmax(prediction)],
264
+ "all_results": dict(zip(MAL_CLASSES, softmax(prediction)))
265
  }
266
  except Exception as e:
267
  logger.error(f"Prediction error: {str(e)}")
 
283
  detail="Invalid image file format"
284
  )
285
 
286
+ # @app.get("/analyse/{file_name}")
287
+ # async def analyse(file_name: str):
288
+ # """Analyze image files"""
289
+ # sanitized_name = Path(file_name).name
290
+ # file_path = os.path.join(UPLOAD_DIR, sanitized_name)
291
 
292
+ # if not os.path.exists(file_path):
293
+ # raise HTTPException(
294
+ # status_code=status.HTTP_404_NOT_FOUND,
295
+ # detail="File not found"
296
+ # )
297
 
298
+ # try:
299
+ # img_array = await process_image(file_path)
300
+ # result = predict_malware(img_array)
301
+ # return result
302
+ # except HTTPException as he:
303
+ # raise he
304
+ # except Exception as e:
305
+ # logger.error(f"Analysis error: {str(e)}")
306
+ # raise HTTPException(
307
+ # status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
308
+ # detail=f"Analysis failed: {str(e)}"
309
+ # )
310
 
311
  def get_image_width(file_path: str) -> int:
312
  """Determine image width based on file size"""
 
349
  detail="Invalid binary file format"
350
  )
351
 
352
+ @app.get("/image")
353
+ async def image(file_name: str):
354
+ if os.path.exists(file_name):
355
+ return FileResponse(file_name, media_type="image/png")
356
+ return JSONResponse(content={"error": "Image not found"})
357
+
358
  @app.get("/analysebin/{file_name}")
359
  async def analyse_bin(file_name: str):
360
  """Analyze binary files by converting to images"""
 
374
  convert_binary_to_image(file_path, temp_image, width)
375
  img_array = await process_image(temp_image)
376
  result = predict_malware(img_array)
377
+ result['image'] = temp_image
378
  return result
379
 
380
  except HTTPException as he:
 
385
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
386
  detail=f"Binary analysis failed: {str(e)}"
387
  )
 
 
 
388
 
389
  @app.get("/analyse/yara/{file_name}")
390
  async def analyse_yara(file_name: str):