GranularFireplace commited on
Commit
b6c69a9
·
verified ·
1 Parent(s): dcaabc8

Prevent path traversal, refactor

Browse files
Files changed (1) hide show
  1. app.py +180 -88
app.py CHANGED
@@ -1,100 +1,162 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.responses import JSONResponse
 
 
3
  import tensorflow as tf
4
- import tensorflow.keras as keras
5
  import numpy as np
6
  import os
7
  import shutil
8
- from huggingface_hub import snapshot_download
9
  import cv2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- MAL_CLASSES = ['Adialer.C', 'Agent.FYI', 'Allaple.A', 'Allaple.L', 'Alueron.gen!J', 'Autorun.K', 'C2LOP.P', 'C2LOP.gen!g', 'Dialplatform.B', 'Dontovo.A', 'Fakerean', 'Instantaccess', 'Lolyda.AA1', 'Lolyda.AA2', 'Lolyda.AA3', 'Lolyda.AT', 'Malex.gen!J', 'Obfuscator.AD', 'Rbot!gen', 'Skintrim.N', 'Swizzor.gen!E', 'Swizzor.gen!I', 'VB.AT', 'Wintrim.BX', 'Yuner.A']
12
  UPLOAD_DIR = "uploads"
13
  os.makedirs(UPLOAD_DIR, exist_ok=True)
14
 
15
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  @app.get("/")
18
- def greet_json():
19
  return {"Hello": "World!"}
20
 
21
- @app.post("/upload")
22
- def upload_file(file: UploadFile = File(...)):
 
23
  try:
24
- # Save the uploaded file
25
- file_path = os.path.join(UPLOAD_DIR, file.filename)
26
- with open(file_path, "wb") as buffer:
27
- shutil.copyfileobj(file.file, buffer)
28
- return {"filename": file.filename, "message": "File uploaded successfully"}
 
 
 
 
 
 
29
  except Exception as e:
30
- raise HTTPException(status_code=500, detail=f"Error uploading file: {str(e)}")
 
 
 
 
31
 
32
  @app.get("/files")
33
- def list_files():
 
34
  try:
35
  files = os.listdir(UPLOAD_DIR)
36
  return {"files": files}
37
  except Exception as e:
38
- raise HTTPException(status_code=500, detail=f"Error listing files: {str(e)}")
 
 
 
 
39
 
40
  @app.get("/download/{file_name}")
41
- def download_file(file_name: str):
42
- file_path = os.path.join(UPLOAD_DIR, file_name)
43
- if os.path.exists(file_path):
44
- return JSONResponse(content={"message": "Download endpoint available. Use a client like cURL or a browser to download.", "file_path": file_path})
45
- else:
46
- raise HTTPException(status_code=404, detail="File not found")
47
-
48
- @app.get("/analyse/{file_name}")
49
- def analyse(file_name: str):
50
- download_dir = snapshot_download("GranularFireplace/malware")
51
- print(download_dir)
52
- print(os.path.join(download_dir, 'model_v2_with_weight.keras'))
53
- img = keras.preprocessing.image.load_img(
54
- os.path.join(UPLOAD_DIR, file_name), target_size=(64, 64)
55
- )
56
-
57
- img_array = keras.preprocessing.image.img_to_array(tf.image.rgb_to_grayscale(img))
58
-
59
- img_array = tf.expand_dims(img_array, 0) # Create a batch
60
 
61
- model = tf.keras.models.load_model(os.path.join(download_dir, 'model_v2_with_weight.keras'))
62
- predicted_label = model.predict(img_array)
 
 
 
 
63
 
64
- return {"result": MAL_CLASSES[np.argmax(predicted_label)]}
65
-
66
- def convert_binary_to_grayscale_image(binary_file_path, output_image_path, height=None, width=None):
67
- with open(binary_file_path, "rb") as bin_file:
68
- binary_data = bin_file.read()
69
-
70
- grayscale_image = np.frombuffer(binary_data, dtype=np.uint8)
71
 
72
- total_pixels = len(grayscale_image)
73
- if height is None and width is None:
74
- raise ValueError("Either height or width must be specified.")
75
- elif height is None:
76
- height = total_pixels // width
77
- elif width is None:
78
- width = total_pixels // height
79
-
80
- if height * width != total_pixels:
81
- raise ValueError(
82
- f"The binary file size ({total_pixels}) is not compatible with the specified dimensions ({height}x{width})."
83
  )
84
 
85
- grayscale_image = grayscale_image.reshape((height, width))
86
-
87
- # Save
88
- cv2.imwrite(output_image_path, grayscale_image)
 
 
 
 
 
 
 
 
89
 
90
- print(f"Grayscale image saved to {output_image_path} with dimensions ({height}, {width})")
91
- return grayscale_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- def get_image_width(file_path):
94
- file_size_kb = os.path.getsize(file_path) / 1024 # File size in kilobytes
 
95
 
96
- # Define the file size range and corresponding image width
97
- file_size_ranges = [
98
  (0, 10, 32),
99
  (10, 30, 64),
100
  (30, 60, 128),
@@ -105,31 +167,61 @@ def get_image_width(file_path):
105
  (1000, float('inf'), 1024)
106
  ]
107
 
108
- # Determine the image width based on the file size
109
- for lower_bound, upper_bound, width in file_size_ranges:
110
- if lower_bound <= file_size_kb < upper_bound:
111
  return width
112
 
113
- # If no range matches (unlikely due to the final range), raise an error
114
- raise ValueError("File size is out of expected range.")
115
 
116
- @app.get("/analysebin/{file_name}")
117
- def analyse_bin(file_name: str):
118
- download_dir = snapshot_download("GranularFireplace/malware")
119
- print(download_dir)
120
- print(os.path.join(download_dir, 'model_v2_with_weight.keras'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- convert_binary_to_grayscale_image(os.path.join(UPLOAD_DIR, file_name), "image.png", width=get_image_width(os.path.join(UPLOAD_DIR, file_name)))
123
-
124
- img = keras.preprocessing.image.load_img(
125
- "image.png", target_size=(64, 64)
126
- )
127
-
128
- img_array = keras.preprocessing.image.img_to_array(tf.image.rgb_to_grayscale(img))
129
-
130
- img_array = tf.expand_dims(img_array, 0) # Create a batch
131
 
132
- model = tf.keras.models.load_model(os.path.join(download_dir, 'model_v2_with_weight.keras'))
133
- predicted_label = model.predict(img_array)
 
 
 
 
 
134
 
135
- return {"result": MAL_CLASSES[np.argmax(predicted_label)]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, status
2
+ from fastapi.responses import JSONResponse, FileResponse
3
+ from contextlib import asynccontextmanager
4
+ from pathlib import Path
5
  import tensorflow as tf
 
6
  import numpy as np
7
  import os
8
  import shutil
 
9
  import cv2
10
+ import logging
11
+ import uuid
12
+ from huggingface_hub import snapshot_download
13
+ from typing import Optional
14
+ import aiofiles
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ MAL_CLASSES = ['Adialer.C', 'Agent.FYI', 'Allaple.A', 'Allaple.L', 'Alueron.gen!J',
21
+ 'Autorun.K', 'C2LOP.P', 'C2LOP.gen!g', 'Dialplatform.B', 'Dontovo.A',
22
+ 'Fakerean', 'Instantaccess', 'Lolyda.AA1', 'Lolyda.AA2', 'Lolyda.AA3',
23
+ 'Lolyda.AT', 'Malex.gen!J', 'Obfuscator.AD', 'Rbot!gen', 'Skintrim.N',
24
+ 'Swizzor.gen!E', 'Swizzor.gen!I', 'VB.AT', 'Wintrim.BX', 'Yuner.A']
25
 
 
26
  UPLOAD_DIR = "uploads"
27
  os.makedirs(UPLOAD_DIR, exist_ok=True)
28
 
29
+ # Environment configuration
30
+ MODEL_REPO = os.getenv("MODEL_REPO", "GranularFireplace/malware")
31
+ MODEL_FILE = os.getenv("MODEL_FILE", "model_v2_with_weight.keras")
32
+
33
+ @asynccontextmanager
34
+ async def lifespan(app: FastAPI):
35
+ """Manage model loading and unloading during app lifecycle"""
36
+ try:
37
+ logger.info("Downloading model from Hugging Face Hub...")
38
+ download_dir = snapshot_download(MODEL_REPO)
39
+ app.state.model = tf.keras.models.load_model(os.path.join(download_dir, MODEL_FILE))
40
+ logger.info("Model loaded successfully")
41
+ except Exception as e:
42
+ logger.error(f"Error loading model: {str(e)}")
43
+ raise
44
+ yield
45
+ # Cleanup resources if needed
46
+ app.state.model = None
47
+
48
+ app = FastAPI(lifespan=lifespan)
49
 
50
  @app.get("/")
51
+ async def greet_json():
52
  return {"Hello": "World!"}
53
 
54
+ @app.post("/upload", status_code=status.HTTP_201_CREATED)
55
+ async def upload_file(file: UploadFile = File(...)):
56
+ """Handle file uploads with async operations and path sanitization"""
57
  try:
58
+ # Sanitize filename to prevent path traversal
59
+ filename = Path(file.filename).name
60
+ file_path = os.path.join(UPLOAD_DIR, filename)
61
+
62
+ async with aiofiles.open(file_path, "wb") as buffer:
63
+ content = await file.read()
64
+ await buffer.write(content)
65
+
66
+ logger.info(f"File uploaded successfully: {filename}")
67
+ return {"filename": filename, "message": "File uploaded successfully"}
68
+
69
  except Exception as e:
70
+ logger.error(f"Error uploading file: {str(e)}")
71
+ raise HTTPException(
72
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
73
+ detail=f"Error uploading file: {str(e)}"
74
+ )
75
 
76
  @app.get("/files")
77
+ async def list_files():
78
+ """List all uploaded files"""
79
  try:
80
  files = os.listdir(UPLOAD_DIR)
81
  return {"files": files}
82
  except Exception as e:
83
+ logger.error(f"Error listing files: {str(e)}")
84
+ raise HTTPException(
85
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
86
+ detail=f"Error listing files: {str(e)}"
87
+ )
88
 
89
  @app.get("/download/{file_name}")
90
+ async def download_file(file_name: str):
91
+ """Serve files for download with proper security checks"""
92
+ # Sanitize input filename
93
+ sanitized_name = Path(file_name).name
94
+ file_path = os.path.join(UPLOAD_DIR, sanitized_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ if not os.path.exists(file_path):
97
+ logger.warning(f"File not found: {sanitized_name}")
98
+ raise HTTPException(
99
+ status_code=status.HTTP_404_NOT_FOUND,
100
+ detail="File not found"
101
+ )
102
 
103
+ return FileResponse(file_path, filename=sanitized_name)
 
 
 
 
 
 
104
 
105
+ def predict_malware(img_array: np.ndarray) -> str:
106
+ """Make prediction using the preloaded model"""
107
+ try:
108
+ prediction = app.state.model.predict(img_array)
109
+ return MAL_CLASSES[np.argmax(prediction)]
110
+ except Exception as e:
111
+ logger.error(f"Prediction error: {str(e)}")
112
+ raise HTTPException(
113
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
114
+ detail="Error processing prediction"
 
115
  )
116
 
117
+ async def process_image(file_path: str, target_size: tuple = (64, 64)) -> np.ndarray:
118
+ """Process image file for model input"""
119
+ try:
120
+ img = tf.keras.utils.load_img(file_path, target_size=target_size, color_mode="grayscale")
121
+ img_array = tf.keras.utils.img_to_array(img)
122
+ return tf.expand_dims(img_array, axis=0)
123
+ except Exception as e:
124
+ logger.error(f"Image processing error: {str(e)}")
125
+ raise HTTPException(
126
+ status_code=status.HTTP_400_BAD_REQUEST,
127
+ detail="Invalid image file format"
128
+ )
129
 
130
+ @app.get("/analyse/{file_name}")
131
+ async def analyse(file_name: str):
132
+ """Analyze image files"""
133
+ sanitized_name = Path(file_name).name
134
+ file_path = os.path.join(UPLOAD_DIR, sanitized_name)
135
+
136
+ if not os.path.exists(file_path):
137
+ raise HTTPException(
138
+ status_code=status.HTTP_404_NOT_FOUND,
139
+ detail="File not found"
140
+ )
141
+
142
+ try:
143
+ img_array = await process_image(file_path)
144
+ result = predict_malware(img_array)
145
+ return {"result": result}
146
+ except HTTPException as he:
147
+ raise he
148
+ except Exception as e:
149
+ logger.error(f"Analysis error: {str(e)}")
150
+ raise HTTPException(
151
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
152
+ detail=f"Analysis failed: {str(e)}"
153
+ )
154
 
155
+ def get_image_width(file_path: str) -> int:
156
+ """Determine image width based on file size"""
157
+ file_size_kb = os.path.getsize(file_path) / 1024
158
 
159
+ size_ranges = [
 
160
  (0, 10, 32),
161
  (10, 30, 64),
162
  (30, 60, 128),
 
167
  (1000, float('inf'), 1024)
168
  ]
169
 
170
+ for lower, upper, width in size_ranges:
171
+ if lower <= file_size_kb < upper:
 
172
  return width
173
 
174
+ raise ValueError("File size out of expected range")
 
175
 
176
+ def convert_binary_to_image(binary_path: str, output_path: str, width: int):
177
+ """Convert binary file to grayscale image with error handling"""
178
+ try:
179
+ with open(binary_path, "rb") as f:
180
+ binary_data = f.read()
181
+
182
+ grayscale = np.frombuffer(binary_data, dtype=np.uint8)
183
+ height = len(grayscale) // width
184
+ grayscale = grayscale[:height*width].reshape((height, width))
185
+
186
+ cv2.imwrite(output_path, grayscale)
187
+ logger.debug(f"Image converted: {output_path}")
188
+
189
+ except Exception as e:
190
+ logger.error(f"Binary conversion error: {str(e)}")
191
+ raise HTTPException(
192
+ status_code=status.HTTP_400_BAD_REQUEST,
193
+ detail="Invalid binary file format"
194
+ )
195
 
196
+ @app.get("/analysebin/{file_name}")
197
+ async def analyse_bin(file_name: str):
198
+ """Analyze binary files by converting to images"""
199
+ sanitized_name = Path(file_name).name
200
+ file_path = os.path.join(UPLOAD_DIR, sanitized_name)
 
 
 
 
201
 
202
+ if not os.path.exists(file_path):
203
+ raise HTTPException(
204
+ status_code=status.HTTP_404_NOT_FOUND,
205
+ detail="File not found"
206
+ )
207
+
208
+ temp_image = os.path.join(UPLOAD_DIR, f"temp_{uuid.uuid4()}.png")
209
 
210
+ try:
211
+ width = get_image_width(file_path)
212
+ convert_binary_to_image(file_path, temp_image, width)
213
+ img_array = await process_image(temp_image)
214
+ result = predict_malware(img_array)
215
+ return {"result": result}
216
+
217
+ except HTTPException as he:
218
+ raise he
219
+ except Exception as e:
220
+ logger.error(f"Binary analysis error: {str(e)}")
221
+ raise HTTPException(
222
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
223
+ detail=f"Binary analysis failed: {str(e)}"
224
+ )
225
+ finally:
226
+ if os.path.exists(temp_image):
227
+ os.remove(temp_image)