File size: 15,460 Bytes
b6c69a9
 
 
 
47172d4
 
b0eb377
bf59308
8ab568a
b6c69a9
 
53debb8
 
42da951
b6c69a9
53debb8
b6c69a9
53debb8
b6c69a9
 
 
 
 
 
 
 
 
 
b0eb377
a4a06e0
 
7cc399a
 
53debb8
41fb458
b6c69a9
 
 
 
53debb8
 
 
 
 
 
 
 
 
 
 
 
7cc399a
53debb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2ea2f
2cb2d02
3a2ea2f
 
 
 
 
8c62a0b
 
 
2cb2d02
42da951
5d22c48
 
 
8c62a0b
 
 
 
 
 
 
 
 
 
ea54715
8c62a0b
ea54715
8c62a0b
 
 
ea54715
 
8c62a0b
 
 
 
 
 
 
 
 
 
 
 
 
ea54715
8c62a0b
 
ea54715
8c62a0b
 
 
 
 
 
 
ea54715
 
8c62a0b
 
3a2ea2f
 
53debb8
3a2ea2f
53debb8
305adbb
 
3a2ea2f
53debb8
576ef61
53debb8
 
576ef61
53debb8
 
2cb2d02
53debb8
2cb2d02
53debb8
576ef61
53debb8
576ef61
 
 
 
 
3a2ea2f
53debb8
 
3a2ea2f
53debb8
b6c69a9
 
53debb8
 
 
 
 
b6c69a9
53debb8
b6c69a9
 
 
 
53debb8
 
 
 
 
 
 
 
 
 
 
 
b6c69a9
53debb8
b6c69a9
53debb8
b6c69a9
53debb8
 
b6c69a9
53debb8
b6c69a9
 
41fb458
 
b6c69a9
41fb458
47172d4
b6c69a9
 
 
47172d4
b6c69a9
 
 
 
 
 
 
 
 
 
 
47172d4
b6c69a9
 
 
 
 
47172d4
 
b6c69a9
 
47172d4
 
27d9f2a
 
47172d4
b6c69a9
 
 
 
 
47172d4
 
b6c69a9
 
 
 
 
b0eb377
b6c69a9
 
 
 
 
 
b0eb377
b6c69a9
8ab568a
8d18eb0
 
 
 
b6c69a9
 
 
 
f4a91cd
8d18eb0
 
f4a91cd
8d18eb0
b6c69a9
 
 
 
 
8ab568a
 
b6c69a9
 
 
 
 
 
 
 
 
 
 
 
8ab568a
eb27cf6
 
 
 
 
b6c69a9
eb27cf6
 
 
 
 
b6c69a9
eb27cf6
 
 
 
 
 
 
 
 
 
 
 
8ab568a
b6c69a9
 
 
dcaabc8
b6c69a9
dcaabc8
 
 
 
 
 
 
 
 
 
b6c69a9
 
dcaabc8
 
b6c69a9
dcaabc8
b6c69a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ab568a
00e11d1
eb27cf6
27d9f2a
572f65a
eb27cf6
 
b6c69a9
 
 
 
 
8ab568a
b6c69a9
 
 
 
 
 
 
8ab568a
b6c69a9
 
 
 
 
87891c6
c1f7192
27d9f2a
87891c6
b6c69a9
 
 
 
 
 
 
 
 
9885f25
8462688
42a221c
9885f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42a221c
 
 
 
 
62f0c6e
42a221c
62f0c6e
 
 
 
 
42a221c
62f0c6e
 
42a221c
62f0c6e
 
9885f25
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.responses import JSONResponse, FileResponse
from contextlib import asynccontextmanager
from pathlib import Path
import tensorflow as tf
import numpy as np
import os
import shutil
import cv2
import logging
import uuid
import yara
import asyncio
import re
from huggingface_hub import snapshot_download
from typing import Optional, List, Dict
import aiofiles
from fastapi.concurrency import run_in_threadpool

# Root logging configuration plus the module-level logger used throughout this file
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Malware family labels. predict_malware() indexes this list with the argmax of
# the model output and zips it with the probabilities, so the order matters.
MAL_CLASSES = [
    'Adialer.C',
    'Agent.FYI',
    'Allaple.A',
    'Allaple.L',
    'Alueron.gen!J',
    'Autorun.K',
    'C2LOP.P',
    'C2LOP.gen!g',
    'Dialplatform.B',
    'Dontovo.A',
    'Fakerean',
    'Instantaccess',
    'Lolyda.AA1',
    'Lolyda.AA2',
    'Lolyda.AA3',
    'Lolyda.AT',
    'Malex.gen!J',
    'Obfuscator.AD',
    'Rbot!gen',
    'Skintrim.N',
    'Swizzor.gen!E',
    'Swizzor.gen!I',
    'VB.AT',
    'Wintrim.BX',
    'Yuner.A',
]

# Directory that receives uploads and generated images; created at import time
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Source of the YARA rules and the local directory they are cloned into
YARA_REPO_URL = "https://github.com/t-tani/defender2yara.git"
YARA_REPO_BRANCH = "yara-rules"
YARA_REPO_DIR = "defenderyara"

# Model location on the Hugging Face Hub, overridable through the environment
MODEL_REPO = os.getenv("MODEL_REPO", "GranularFireplace/malware")
MODEL_FILE = os.getenv("MODEL_FILE", "model_v2_with_weight.keras")

async def clone_yara_repo():
    """Fetch a fresh checkout of the YARA rules branch via ``git clone``.

    Any directory already present at ``YARA_REPO_DIR`` is deleted first.

    Returns:
        The checkout ``Path`` when git exits with status 0, otherwise ``None``.
        Failures (non-zero exit or any raised exception) are logged and never
        propagated to the caller.
    """
    try:
        checkout = Path(YARA_REPO_DIR)

        # Start from a clean slate so the clone target does not already exist
        if checkout.exists():
            logger.info("Removing existing YARA rules repository")
            shutil.rmtree(checkout)

        logger.info(f"Cloning YARA rules from {YARA_REPO_URL}")
        git_args = (
            'git', 'clone', '-b', YARA_REPO_BRANCH, '--single-branch',
            YARA_REPO_URL, str(checkout),
        )
        git_proc = await asyncio.create_subprocess_exec(
            *git_args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        # Only stderr is reported; captured stdout is discarded
        _, err_bytes = await git_proc.communicate()

        if git_proc.returncode == 0:
            logger.info("YARA rules repository cloned successfully")
            return checkout

        logger.error(f"Failed to clone YARA repo: {err_bytes.decode()}")
        return None

    except Exception as e:
        logger.error(f"Error cloning YARA repository: {str(e)}")
        return None

def preprocess_yara_rules(repo_path: Path) -> Path:
    """Copy the repository's ``*.yara`` files with sanitized, unique rule names.

    Every ``*.yara`` file under ``repo_path`` (except files named
    ``misc.yara``) is rewritten line by line into ``processed_yara_rules``,
    keeping its relative path. The output directory is wiped and recreated on
    every call. Lines are read as UTF-8 with undecodable bytes replaced.

    For each line that starts (after leading whitespace) with ``rule <name>``:

    * characters outside ``[a-zA-Z0-9_]`` in the name become ``_``;
    * a name that would start with a digit is prefixed with ``_``;
    * a name already used by an earlier rule (across all files) gets a numeric
      suffix ``_<n>``, bumped until the result is unused.

    Args:
        repo_path: Root directory of the cloned rules repository.

    Returns:
        Path of the directory holding the processed copies.
    """
    processed_dir = Path("processed_yara_rules")
    if processed_dir.exists():
        shutil.rmtree(processed_dir)
    processed_dir.mkdir()

    rule_pattern = re.compile(r"rule\s+([^\s{]+)")
    invalid_chars = re.compile(r'[^a-zA-Z0-9_]')
    # Maps every name already emitted to the last duplicate suffix handed out for it
    seen_rules = {}
    rule_counter = 0

    for yara_file in repo_path.glob("**/*.yara"):
        if yara_file.name == "misc.yara":
            logger.info(f"Skipping {yara_file} as it does not belong to any malware family")
            continue

        # Preserve directory structure
        try:
            relative_path = yara_file.relative_to(repo_path)
        except ValueError:
            continue  # Skip files not in repo path

        processed_file = processed_dir / relative_path
        processed_file.parent.mkdir(parents=True, exist_ok=True)

        new_content = []
        with open(yara_file, "r", encoding="utf-8", errors='replace') as f:
            for line in f:
                line = line.rstrip('\n')
                stripped = line.lstrip()
                match = rule_pattern.match(stripped)

                if match:
                    original_name = match.group(1)
                    # The pattern guarantees at least one character, and the
                    # substitution is one-for-one, so clean_name is never empty.
                    clean_name = invalid_chars.sub('_', original_name)
                    if not clean_name[0].isalpha() and clean_name[0] != '_':
                        clean_name = f"_{clean_name}"

                    if clean_name in seen_rules:
                        # Bump the suffix until the candidate is unused, so a
                        # generated "X_1" cannot clash with a real rule "X_1".
                        while True:
                            seen_rules[clean_name] += 1
                            final_name = f"{clean_name}_{seen_rules[clean_name]}"
                            if final_name not in seen_rules:
                                break
                        seen_rules[final_name] = 0
                        rule_counter += 1
                        logger.debug(f"Renamed duplicate rule: {original_name} -> {final_name}")
                    else:
                        seen_rules[clean_name] = 0  # Initialize count
                        final_name = clean_name
                        if clean_name != original_name:
                            logger.debug(f"Sanitized rule name: {original_name} -> {clean_name}")

                    if final_name != original_name:
                        # Splice at the matched span instead of str.replace():
                        # a name such as "e" or "ru" also occurs inside the
                        # "rule" keyword, which replace() would hit first.
                        indent = len(line) - len(stripped)
                        name_start = indent + match.start(1)
                        name_end = indent + match.end(1)
                        line = line[:name_start] + final_name + line[name_end:]

                new_content.append(line + '\n')

        with open(processed_file, "w", encoding="utf-8") as f:
            f.writelines(new_content)

    logger.info(f"Processed {rule_counter} duplicate rules")
    return processed_dir

def compile_yara_rules(repo_path: Path) -> Optional[yara.Rules]:
    """Preprocess the cloned rules and compile them into one ``yara.Rules``.

    Each processed file is registered under its own namespace key
    (``"<path>_<index>"``) and all files are compiled in a single
    ``yara.compile`` call.

    Returns:
        The compiled rules, or ``None`` when no ``*.yara`` file is found or
        when preprocessing/compilation raises (the error is logged).
    """
    try:
        rules_root = preprocess_yara_rules(repo_path)
        sources = list(rules_root.glob("**/*.yara"))

        if len(sources) == 0:
            logger.warning("No YARA files found in repository")
            return None

        logger.info(f"Found {len(sources)} YARA files, compiling rules")

        # namespace key -> file path, the mapping shape passed as `filepaths`
        namespaces = {}
        for index, source in enumerate(sources):
            try:
                namespaces[f"{source}_{index}"] = str(source)
            except Exception as e:
                logger.warning(f"Error processing {source}: {str(e)}")

        compiled = yara.compile(filepaths=namespaces)
        return compiled

    except yara.SyntaxError as e:
        logger.error(f"YARA syntax error: {str(e)}")
        return None

    except Exception as e:
        logger.error(f"Error compiling YARA rules: {str(e)}")
        return None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and YARA rules on startup and drop them on shutdown.

    Startup downloads the model snapshot, loads the Keras model from it, then
    clones and compiles the YARA rules. Missing YARA rules only produce a
    warning; any exception raised during startup is logged and re-raised.
    After the ``yield`` both state attributes are reset to ``None``.
    """
    state = app.state
    state.model = None
    state.yara_rules = None

    try:
        # --- ML model ---
        logger.info("Downloading model from Hugging Face Hub...")
        snapshot_dir = snapshot_download(MODEL_REPO)
        model_path = os.path.join(snapshot_dir, MODEL_FILE)
        state.model = tf.keras.models.load_model(model_path)
        logger.info("Model loaded successfully")

        # --- YARA rules ---
        rules_checkout = await clone_yara_repo()
        if not rules_checkout:
            logger.warning("YARA rules unavailable")
        else:
            state.yara_rules = compile_yara_rules(rules_checkout)
            if state.yara_rules:
                logger.info("YARA rules compiled successfully")
            else:
                logger.warning("No valid YARA rules compiled")

    except Exception as e:
        logger.error(f"Initialization error: {str(e)}")
        raise

    yield

    # Shutdown: release references held on the application state
    state.model = None
    state.yara_rules = None

# Application instance; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

@app.get("/")
async def greet_json():
    """Root endpoint: respond with a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting

@app.post("/upload", status_code=status.HTTP_201_CREATED)
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded file in ``UPLOAD_DIR`` under its sanitized base name.

    The client-supplied name is reduced to its final path component so it
    cannot point outside the upload directory. The body is copied in
    fixed-size chunks rather than being read into memory in one piece.

    Raises:
        HTTPException: 400 when the upload carries no usable filename
            (missing, empty, or resolving to ``..``); 500 when writing fails.
    """
    # Sanitize filename to prevent path traversal
    filename = Path(file.filename or "").name
    if not filename or filename == "..":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or missing filename"
        )

    file_path = os.path.join(UPLOAD_DIR, filename)
    chunk_size = 1024 * 1024  # copy 1 MiB at a time

    try:
        async with aiofiles.open(file_path, "wb") as buffer:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                await buffer.write(chunk)
    except Exception as e:
        logger.error(f"Error uploading file: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error uploading file: {str(e)}"
        )

    logger.info(f"File uploaded successfully: {filename}")
    return {"filename": filename, "message": "File uploaded successfully"}

@app.get("/files")
async def list_files():
    """Return the entries of ``UPLOAD_DIR`` whose names do not start with ``temp``.

    Raises:
        HTTPException: 500 when the directory listing fails.
    """
    try:
        visible = []
        for entry in os.listdir(UPLOAD_DIR):
            # Names starting with "temp" are hidden from the listing
            if entry.startswith('temp'):
                continue
            visible.append(entry)
        return {"files": visible}
    except Exception as e:
        logger.error(f"Error listing files: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error listing files: {str(e)}"
        )

@app.get("/download/{file_name}")
async def download_file(file_name: str):
    """Serve a previously uploaded file as a download.

    The requested name is reduced to its final path component before it is
    joined with ``UPLOAD_DIR``.

    Raises:
        HTTPException: 404 when the name does not resolve to a regular file.
    """
    # Sanitize input filename
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    # isfile() rather than exists(): a directory (e.g. "..") must be reported
    # as not found instead of being handed to FileResponse.
    if not os.path.isfile(file_path):
        logger.warning(f"File not found: {sanitized_name}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    return FileResponse(file_path, filename=sanitized_name)

def softmax(x):
    """Return ``exp(x)`` normalized by its sum over all elements of ``x``.

    The maximum of ``x`` is subtracted before exponentiation so that large
    inputs do not overflow.
    """
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    total = np.sum(weights)
    return weights / total

def predict_malware(img_array: np.ndarray) -> Dict[str, object]:
    """Classify a preprocessed image with the model stored on ``app.state``.

    Args:
        img_array: Model input as produced by ``process_image``.

    Returns:
        A dict with ``"result"`` (the label from ``MAL_CLASSES`` at the argmax
        of the model output) and ``"all_results"`` (label -> softmax value).

    Raises:
        HTTPException: 500 when prediction or post-processing fails.
    """
    try:
        prediction = app.state.model.predict(img_array)
        # NOTE(review): softmax is applied to whatever the model emits; if the
        # model already ends in a softmax layer these values are re-normalized
        # a second time — confirm against the model definition.
        probabilities = softmax(prediction).flatten().tolist()
        # Convert the NumPy integer to a plain int before indexing the list
        top_index = int(np.argmax(prediction))
        return {
            "result": MAL_CLASSES[top_index],
            "all_results": dict(zip(MAL_CLASSES, probabilities))
        }
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing prediction"
        )

async def process_image(file_path: str, target_size: tuple = (64, 64)) -> np.ndarray:
    """Load an image as grayscale, resize it and add a leading batch axis.

    Args:
        file_path: Path of the image file to load.
        target_size: Size passed to ``load_img``; defaults to ``(64, 64)``.

    Raises:
        HTTPException: 400 when loading or converting the image fails.
    """
    try:
        pil_image = tf.keras.utils.load_img(
            file_path,
            color_mode="grayscale",
            target_size=target_size,
        )
        pixels = tf.keras.utils.img_to_array(pil_image)
        # NOTE(review): tf.expand_dims presumably yields a TensorFlow tensor
        # rather than the annotated np.ndarray — confirm before relying on it.
        batched = tf.expand_dims(pixels, axis=0)
        return batched
    except Exception as e:
        logger.error(f"Image processing error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid image file format"
        )

# @app.get("/analyse/{file_name}")
# async def analyse(file_name: str):
#     """Analyze image files"""
#     sanitized_name = Path(file_name).name
#     file_path = os.path.join(UPLOAD_DIR, sanitized_name)
    
#     if not os.path.exists(file_path):
#         raise HTTPException(
#             status_code=status.HTTP_404_NOT_FOUND,
#             detail="File not found"
#         )
    
#     try:
#         img_array = await process_image(file_path)
#         result = predict_malware(img_array)
#         return result
#     except HTTPException as he:
#         raise he
#     except Exception as e:
#         logger.error(f"Analysis error: {str(e)}")
#         raise HTTPException(
#             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
#             detail=f"Analysis failed: {str(e)}"
#         )

def get_image_width(file_path: str) -> int:
    """Choose an image width from the size of the file at ``file_path``.

    The size in KB (bytes / 1024) is looked up in a table of half-open
    ranges: below 10 KB -> 32, below 30 -> 64, below 60 -> 128, below 100 ->
    256, below 200 -> 384, below 500 -> 512, below 1000 -> 768, otherwise 1024.

    Raises:
        ValueError: if the size matches no range.
    """
    size_kb = os.path.getsize(file_path) / 1024

    # (exclusive upper bound in KB, width), in ascending order starting at 0 KB
    width_by_upper_bound = (
        (10, 32),
        (30, 64),
        (60, 128),
        (100, 256),
        (200, 384),
        (500, 512),
        (1000, 768),
        (float('inf'), 1024),
    )

    if size_kb >= 0:
        for upper_kb, width in width_by_upper_bound:
            if size_kb < upper_kb:
                return width

    raise ValueError("File size out of expected range")

def convert_binary_to_image(binary_path: str, output_path: str, width: int):
    """Write the bytes of a file as a grayscale image of the given width.

    The file is read as unsigned 8-bit values and reshaped to
    ``(len // width, width)``; trailing bytes that do not fill a complete row
    are dropped. The result is written with ``cv2.imwrite``.

    Args:
        binary_path: File whose raw bytes become pixel values.
        output_path: Destination image path.
        width: Number of pixels per row.

    Raises:
        HTTPException: 400 when the file cannot be read, holds less than one
            full row of data, or the image cannot be written.
    """
    try:
        with open(binary_path, "rb") as f:
            binary_data = f.read()

        grayscale = np.frombuffer(binary_data, dtype=np.uint8)
        height = len(grayscale) // width
        if height == 0:
            # Fewer than `width` bytes: there is no complete row to draw
            raise ValueError(
                f"{len(grayscale)} bytes is less than one row of width {width}"
            )
        grayscale = grayscale[:height * width].reshape((height, width))

        # imwrite signals some failures through its return value only
        if not cv2.imwrite(output_path, grayscale):
            raise OSError(f"Could not write image: {output_path}")
        logger.debug(f"Image converted: {output_path}")

    except Exception as e:
        logger.error(f"Binary conversion error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid binary file format"
        )

@app.get("/image/{file_name}")
async def image(file_name: str):
    """Serve a file from ``UPLOAD_DIR`` with the ``image/png`` media type.

    The requested name is reduced to its final path component, as the other
    file endpoints do, and must resolve to a regular file. Otherwise the JSON
    body ``{"error": "Image not found"}`` is returned with HTTP status 404.
    """
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    if os.path.isfile(file_path):
        return FileResponse(file_path, media_type="image/png")

    return JSONResponse(
        status_code=status.HTTP_404_NOT_FOUND,
        content={"error": "Image not found"}
    )

@app.get("/analysebin/{file_name}")
async def analyse_bin(file_name: str):
    """Classify an uploaded binary by rendering it as an image first.

    The file's bytes are converted into a grayscale PNG stored in
    ``UPLOAD_DIR`` under a ``temp_<uuid>.png`` name, which is then fed to the
    model. The PNG is left in place and its file name is returned as
    ``image_location`` alongside the prediction.

    Raises:
        HTTPException: 404 when the file does not exist; errors raised by the
            helper steps are passed through unchanged; anything else is
            reported as 500.
    """
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    if not os.path.exists(file_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    temp_image = os.path.join(UPLOAD_DIR, f"temp_{uuid.uuid4()}.png")

    try:
        width = get_image_width(file_path)
        convert_binary_to_image(file_path, temp_image, width)
        img_array = await process_image(temp_image)
        result = predict_malware(img_array)
        return {
            **result,
            # File name only; derived from the path instead of slicing a fixed
            # number of characters so it stays correct if UPLOAD_DIR changes.
            "image_location": os.path.basename(temp_image)
        }

    except HTTPException:
        # Already carries the intended status code; re-raise untouched
        raise
    except Exception as e:
        logger.error(f"Binary analysis error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Binary analysis failed: {str(e)}"
        )

@app.get("/analyse/yara/{file_name}")
async def analyse_yara(file_name: str):
    """Scan an uploaded file with the compiled YARA rules.

    Returns ``{"result": "Does not match"}`` when nothing matches, otherwise
    ``{"result": "Found", "matches": [...]}`` with the rule, namespace, tags
    and meta of every match.

    Raises:
        HTTPException: 404 when the file does not exist, 503 when no compiled
            rules are available, 500 when the scan fails.
    """
    sanitized_name = Path(file_name).name
    target = os.path.join(UPLOAD_DIR, sanitized_name)

    if not os.path.exists(target):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    if not app.state.yara_rules:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="YARA rules not available"
        )

    try:
        # The scan is handed to the thread pool instead of being called inline
        hits = await run_in_threadpool(app.state.yara_rules.match, target)

        if not hits:
            return {"result": "Does not match"}

        details = []
        for hit in hits:
            details.append({
                "rule": hit.rule,
                "namespace": hit.namespace,
                "tags": hit.tags,
                "meta": hit.meta,
            })
        return {"result": "Found", "matches": details}
    except Exception as e:
        logger.error(f"YARA analysis failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="YARA analysis failed"
        )