"""FastAPI malware-analysis service.

Exposes endpoints to upload binaries, classify them with a Keras CNN
(binary -> grayscale image -> 25-way malware-family classifier), and scan
them with YARA rules cloned from the defender2yara GitHub repository.
"""

from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.responses import JSONResponse, FileResponse
from contextlib import asynccontextmanager
from pathlib import Path
import tensorflow as tf
import numpy as np
import os
import shutil
import cv2
import logging
import uuid
import yara
import asyncio
import re
from huggingface_hub import snapshot_download
from typing import Optional, List, Dict
import aiofiles
from fastapi.concurrency import run_in_threadpool

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Class labels in the exact order of the model's output vector
# (index of argmax maps directly into this list).
MAL_CLASSES = ['Adialer.C', 'Agent.FYI', 'Allaple.A', 'Allaple.L',
               'Alueron.gen!J', 'Autorun.K', 'C2LOP.P', 'C2LOP.gen!g',
               'Dialplatform.B', 'Dontovo.A', 'Fakerean', 'Instantaccess',
               'Lolyda.AA1', 'Lolyda.AA2', 'Lolyda.AA3', 'Lolyda.AT',
               'Malex.gen!J', 'Obfuscator.AD', 'Rbot!gen', 'Skintrim.N',
               'Swizzor.gen!E', 'Swizzor.gen!I', 'VB.AT', 'Wintrim.BX',
               'Yuner.A']

UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

YARA_REPO_URL = "https://github.com/t-tani/defender2yara.git"
YARA_REPO_BRANCH = "yara-rules"
YARA_REPO_DIR = "defenderyara"

# Environment configuration (overridable for deployment)
MODEL_REPO = os.getenv("MODEL_REPO", "GranularFireplace/malware")
MODEL_FILE = os.getenv("MODEL_FILE", "model_v2_with_weight.keras")


async def clone_yara_repo():
    """Clone the YARA rules repository asynchronously.

    Returns:
        Path to the local clone on success, or ``None`` on any failure.
        Failure is non-fatal: callers treat it as "YARA unavailable".
    """
    try:
        repo_path = Path(YARA_REPO_DIR)

        # Always start from a fresh clone so stale rules never linger
        if repo_path.exists():
            logger.info("Removing existing YARA rules repository")
            shutil.rmtree(repo_path)

        logger.info(f"Cloning YARA rules from {YARA_REPO_URL}")
        proc = await asyncio.create_subprocess_exec(
            'git', 'clone', '-b', YARA_REPO_BRANCH, '--single-branch',
            YARA_REPO_URL, str(repo_path),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
        stdout, stderr = await proc.communicate()

        if proc.returncode != 0:
            logger.error(f"Failed to clone YARA repo: {stderr.decode()}")
            return None

        logger.info("YARA rules repository cloned successfully")
        return repo_path
    except Exception as e:
        logger.error(f"Error cloning YARA repository: {str(e)}")
        return None


def preprocess_yara_rules(repo_path: Path) -> Path:
    """Preprocess YARA rules to fix syntax issues and ensure unique rule names.

    Copies every ``*.yara`` file under *repo_path* into a fresh
    ``processed_yara_rules`` directory (preserving relative layout),
    sanitizing each rule identifier to ``[A-Za-z0-9_]`` and suffixing
    duplicates with a running counter so the combined compile succeeds.

    Returns:
        Path to the directory holding the processed copies.
    """
    processed_dir = Path("processed_yara_rules")
    if processed_dir.exists():
        shutil.rmtree(processed_dir)
    processed_dir.mkdir()

    rule_pattern = re.compile(r"rule\s+([^\s{]+)")
    seen_rules = {}
    rule_counter = 0

    for yara_file in repo_path.glob("**/*.yara"):
        if yara_file.name == "misc.yara":
            logger.info(f"Skipping {yara_file} as it does not belong to any malware family")
            continue

        # Preserve directory structure
        try:
            relative_path = yara_file.relative_to(repo_path)
        except ValueError:
            continue  # Skip files not in repo path

        processed_file = processed_dir / relative_path
        processed_file.parent.mkdir(parents=True, exist_ok=True)

        new_content = []
        with open(yara_file, "r", encoding="utf-8", errors='replace') as f:
            for line in f:
                line = line.rstrip('\n')
                match = rule_pattern.match(line.strip())
                if match:
                    original_name = match.group(1)
                    # Sanitize rule name to characters YARA accepts
                    clean_name = re.sub(r'[^a-zA-Z0-9_]', '_', original_name)

                    # Ensure valid starting character
                    if not clean_name:
                        clean_name = "invalid_rule"
                    elif not clean_name[0].isalpha() and clean_name[0] != '_':
                        clean_name = f"_{clean_name}"

                    # Handle duplicates: rule names must be unique per compile
                    if clean_name in seen_rules:
                        seen_rules[clean_name] += 1
                        new_name = f"{clean_name}_{seen_rules[clean_name]}"
                        line = line.replace(original_name, new_name, 1)
                        rule_counter += 1
                        logger.debug(f"Renamed duplicate rule: {original_name} -> {new_name}")
                    else:
                        seen_rules[clean_name] = 0  # Initialize count
                        if clean_name != original_name:
                            line = line.replace(original_name, clean_name, 1)
                            logger.debug(f"Sanitized rule name: {original_name} -> {clean_name}")

                new_content.append(line + '\n')

        with open(processed_file, "w", encoding="utf-8") as f:
            f.writelines(new_content)

    logger.info(f"Processed {rule_counter} duplicate rules")
    return processed_dir


def compile_yara_rules(repo_path: Path) -> Optional[yara.Rules]:
    """Compile YARA rules from the repository with error handling.

    Runs :func:`preprocess_yara_rules` first, then compiles every
    processed ``*.yara`` file under a unique namespace.

    Returns:
        Compiled ``yara.Rules`` object, or ``None`` if no files were
        found or compilation failed.
    """
    try:
        processed_dir = preprocess_yara_rules(repo_path)
        yara_files = list(processed_dir.glob("**/*.yara"))

        if not yara_files:
            logger.warning("No YARA files found in repository")
            return None

        logger.info(f"Found {len(yara_files)} YARA files, compiling rules")

        # Namespace each file uniquely (path + index) to avoid collisions
        rules = {}
        for i, yara_file in enumerate(yara_files):
            try:
                rules[f"{yara_file}_{i}"] = str(yara_file)
            except Exception as e:
                logger.warning(f"Error processing {yara_file}: {str(e)}")

        return yara.compile(filepaths=rules)
    except yara.SyntaxError as e:
        logger.error(f"YARA syntax error: {str(e)}")
        return None
    except Exception as e:
        logger.error(f"Error compiling YARA rules: {str(e)}")
        return None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifecycle.

    On startup: downloads and loads the Keras model from the Hugging Face
    Hub, then clones and compiles the YARA rules (YARA failures are
    non-fatal). On shutdown: releases both from app state.
    """
    # Initialize app state
    app.state.model = None
    app.state.yara_rules = None

    try:
        # Load ML model
        logger.info("Downloading model from Hugging Face Hub...")
        download_dir = snapshot_download(MODEL_REPO)
        app.state.model = tf.keras.models.load_model(os.path.join(download_dir, MODEL_FILE))
        logger.info("Model loaded successfully")

        # Clone and compile YARA rules
        yara_repo_path = await clone_yara_repo()
        if yara_repo_path:
            app.state.yara_rules = compile_yara_rules(yara_repo_path)
            if app.state.yara_rules:
                logger.info("YARA rules compiled successfully")
            else:
                logger.warning("No valid YARA rules compiled")
        else:
            logger.warning("YARA rules unavailable")
    except Exception as e:
        logger.error(f"Initialization error: {str(e)}")
        raise

    yield

    # Cleanup
    app.state.model = None
    app.state.yara_rules = None


app = FastAPI(lifespan=lifespan)


@app.get("/")
async def greet_json():
    """Health-check / hello endpoint."""
    return {"Hello": "World!"}


@app.post("/upload", status_code=status.HTTP_201_CREATED)
async def upload_file(file: UploadFile = File(...)):
    """Handle file uploads with async operations and path sanitization."""
    try:
        # Sanitize filename to prevent path traversal
        filename = Path(file.filename).name
        file_path = os.path.join(UPLOAD_DIR, filename)

        async with aiofiles.open(file_path, "wb") as buffer:
            content = await file.read()
            await buffer.write(content)

        # Fix: log the actual filename (original logged a literal placeholder)
        logger.info(f"File uploaded successfully: {filename}")
        return {"filename": filename, "message": "File uploaded successfully"}
    except Exception as e:
        logger.error(f"Error uploading file: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error uploading file: {str(e)}"
        )


@app.get("/files")
async def list_files():
    """List all uploaded files, hiding temporary conversion artifacts."""
    try:
        files = os.listdir(UPLOAD_DIR)
        # temp_* files are intermediate images produced by /analysebin
        filtered_files = [file for file in files if not file.startswith('temp')]
        return {"files": filtered_files}
    except Exception as e:
        logger.error(f"Error listing files: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error listing files: {str(e)}"
        )


@app.get("/download/{file_name}")
async def download_file(file_name: str):
    """Serve files for download with proper security checks."""
    # Sanitize input filename to prevent path traversal
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    if not os.path.exists(file_path):
        logger.warning(f"File not found: {sanitized_name}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    return FileResponse(file_path, filename=sanitized_name)


def softmax(x):
    """Numerically stable softmax over the raw model output."""
    exp_x = np.exp(x - np.max(x))  # Subtract max for numerical stability
    return exp_x / np.sum(exp_x)


def predict_malware(img_array: np.ndarray) -> Dict:
    """Make a prediction using the preloaded model.

    Args:
        img_array: Batched grayscale image tensor as produced by
            :func:`process_image`.

    Returns:
        Dict with ``result`` (top class name) and ``all_results``
        (class-name -> probability mapping).

    Raises:
        HTTPException: 500 if the model call fails.
    """
    try:
        prediction = app.state.model.predict(img_array)
        probabilities = softmax(prediction).flatten().tolist()
        return {
            "result": MAL_CLASSES[np.argmax(prediction)],
            "all_results": dict(zip(MAL_CLASSES, probabilities))
        }
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing prediction"
        )


async def process_image(file_path: str, target_size: tuple = (64, 64)) -> np.ndarray:
    """Process an image file into a batched grayscale tensor for the model.

    NOTE(review): this coroutine performs only blocking TF/PIL work and
    never awaits; consider run_in_threadpool at the call site if event-loop
    stalls are observed.

    Raises:
        HTTPException: 400 if the file is not a readable image.
    """
    try:
        img = tf.keras.utils.load_img(file_path, target_size=target_size, color_mode="grayscale")
        img_array = tf.keras.utils.img_to_array(img)
        return tf.expand_dims(img_array, axis=0)  # Add batch dimension
    except Exception as e:
        logger.error(f"Image processing error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid image file format"
        )


def get_image_width(file_path: str) -> int:
    """Determine image width (pixels) based on file size.

    Uses the conventional Nataraj et al. size-to-width table for
    malware-image conversion.

    Raises:
        ValueError: if the size falls outside every range (unreachable
        for non-negative sizes since the table covers [0, inf)).
    """
    file_size_kb = os.path.getsize(file_path) / 1024
    size_ranges = [
        (0, 10, 32), (10, 30, 64), (30, 60, 128), (60, 100, 256),
        (100, 200, 384), (200, 500, 512), (500, 1000, 768),
        (1000, float('inf'), 1024)
    ]
    for lower, upper, width in size_ranges:
        if lower <= file_size_kb < upper:
            return width
    raise ValueError("File size out of expected range")


def convert_binary_to_image(binary_path: str, output_path: str, width: int):
    """Convert a binary file to a grayscale image with error handling.

    Bytes are read as uint8, truncated to a whole number of rows of
    *width* pixels, and written as a PNG.

    Raises:
        HTTPException: 400 if the file cannot be converted (including
        files too small for even a single row).
    """
    try:
        with open(binary_path, "rb") as f:
            binary_data = f.read()

        grayscale = np.frombuffer(binary_data, dtype=np.uint8)
        height = len(grayscale) // width
        if height == 0:
            # Too few bytes for one row; fail early with the same 400 the
            # generic handler would produce on an empty image write.
            raise ValueError("binary smaller than one image row")
        grayscale = grayscale[:height * width].reshape((height, width))
        cv2.imwrite(output_path, grayscale)
        logger.debug(f"Image converted: {output_path}")
    except Exception as e:
        logger.error(f"Binary conversion error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid binary file format"
        )


@app.get("/image/{file_name}")
async def image(file_name: str):
    """Serve a previously generated PNG from the upload directory."""
    # Fix: sanitize like the other endpoints to prevent path traversal
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)
    if os.path.exists(file_path):
        return FileResponse(file_path, media_type="image/png")
    return JSONResponse(
        content={"error": "Image not found"},
        status_code=status.HTTP_404_NOT_FOUND
    )


@app.get("/analysebin/{file_name}")
async def analyse_bin(file_name: str):
    """Analyze binary files by converting them to images and classifying.

    On success the intermediate PNG is kept so it can be fetched via
    ``/image/{name}``; on failure it is cleaned up.
    """
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    if not os.path.exists(file_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    temp_image = os.path.join(UPLOAD_DIR, f"temp_{uuid.uuid4()}.png")
    try:
        width = get_image_width(file_path)
        convert_binary_to_image(file_path, temp_image, width)
        img_array = await process_image(temp_image)
        result = predict_malware(img_array)
        return {
            **result,
            # Filename only (was a brittle [8:] slice of the joined path)
            "image_location": os.path.basename(temp_image)
        }
    except HTTPException:
        # Don't leak the temp image on failure
        if os.path.exists(temp_image):
            os.remove(temp_image)
        raise
    except Exception as e:
        if os.path.exists(temp_image):
            os.remove(temp_image)
        logger.error(f"Binary analysis error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Binary analysis failed: {str(e)}"
        )


@app.get("/analyse/yara/{file_name}")
async def analyse_yara(file_name: str):
    """Analyze a file using the YARA rules compiled at startup.

    Returns 503 if rules failed to load, 404 if the file is missing,
    otherwise the list of matching rules (or "Does not match").
    """
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    if not os.path.exists(file_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    if not app.state.yara_rules:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="YARA rules not available"
        )

    try:
        # Rule matching is CPU-bound; keep it off the event loop
        matches = await run_in_threadpool(app.state.yara_rules.match, file_path)
        if matches:
            return {
                "result": "Found",
                "matches": [{
                    "rule": match.rule,
                    "namespace": match.namespace,
                    "tags": match.tags,
                    "meta": match.meta,
                    # "strings": [s for s in match.strings]
                } for match in matches]
            }
        else:
            return {"result": "Does not match"}
    except Exception as e:
        logger.error(f"YARA analysis failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="YARA analysis failed"
        )