Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException, status | |
| from fastapi.responses import JSONResponse, FileResponse | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| import tensorflow as tf | |
| import numpy as np | |
| import os | |
| import shutil | |
| import cv2 | |
| import logging | |
| import uuid | |
| import yara | |
| import asyncio | |
| import re | |
| from huggingface_hub import snapshot_download | |
| from typing import Optional, List, Dict | |
| import aiofiles | |
| from fastapi.concurrency import run_in_threadpool | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
# Class labels for the Malimg-style malware-family classifier.
# Index order must match the model's output layer — do not reorder.
MAL_CLASSES = ['Adialer.C', 'Agent.FYI', 'Allaple.A', 'Allaple.L', 'Alueron.gen!J',
               'Autorun.K', 'C2LOP.P', 'C2LOP.gen!g', 'Dialplatform.B', 'Dontovo.A',
               'Fakerean', 'Instantaccess', 'Lolyda.AA1', 'Lolyda.AA2', 'Lolyda.AA3',
               'Lolyda.AT', 'Malex.gen!J', 'Obfuscator.AD', 'Rbot!gen', 'Skintrim.N',
               'Swizzor.gen!E', 'Swizzor.gen!I', 'VB.AT', 'Wintrim.BX', 'Yuner.A']
# Directory where uploaded samples (and temp_* conversion images) are stored.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
# Source of the YARA ruleset (defender2yara project, pre-built rules branch).
YARA_REPO_URL = "https://github.com/t-tani/defender2yara.git"
YARA_REPO_BRANCH = "yara-rules"
YARA_REPO_DIR = "defenderyara"
# Environment configuration — Hugging Face Hub repo and model filename,
# overridable via MODEL_REPO / MODEL_FILE env vars.
MODEL_REPO = os.getenv("MODEL_REPO", "GranularFireplace/malware")
MODEL_FILE = os.getenv("MODEL_FILE", "model_v2_with_weight.keras")
async def clone_yara_repo():
    """Clone YARA rules repository asynchronously.

    Returns the checkout Path on success, or None on any failure
    (clone errors are logged, never raised).
    """
    try:
        target = Path(YARA_REPO_DIR)
        if target.exists():
            # Always start from a clean checkout.
            logger.info("Removing existing YARA rules repository")
            shutil.rmtree(target)

        logger.info(f"Cloning YARA rules from {YARA_REPO_URL}")
        proc = await asyncio.create_subprocess_exec(
            'git', 'clone', '-b', YARA_REPO_BRANCH, '--single-branch',
            YARA_REPO_URL, str(target),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()

        if proc.returncode != 0:
            logger.error(f"Failed to clone YARA repo: {stderr.decode()}")
            return None

        logger.info("YARA rules repository cloned successfully")
        return target
    except Exception as e:
        logger.error(f"Error cloning YARA repository: {str(e)}")
        return None
def preprocess_yara_rules(repo_path: Path) -> Path:
    """Preprocess YARA rules to fix syntax issues and ensure unique rule names"""
    out_dir = Path("processed_yara_rules")
    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir()

    rule_re = re.compile(r"rule\s+([^\s{]+)")
    name_counts = {}   # sanitized name -> number of duplicates seen so far
    renamed_total = 0

    for src in repo_path.glob("**/*.yara"):
        if src.name == "misc.yara":
            logger.info(f"Skipping {src} as it does not belong to any malware family")
            continue

        # Mirror the repo's directory layout under out_dir.
        try:
            rel = src.relative_to(repo_path)
        except ValueError:
            continue  # file lies outside the repo tree

        dst = out_dir / rel
        dst.parent.mkdir(parents=True, exist_ok=True)

        rewritten = []
        with open(src, "r", encoding="utf-8", errors='replace') as fh:
            for raw in fh:
                text = raw.rstrip('\n')
                hit = rule_re.match(text.strip())
                if hit:
                    original = hit.group(1)
                    # YARA identifiers allow only [A-Za-z0-9_] and must not
                    # start with a digit.
                    sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', original)
                    if not sanitized:
                        sanitized = "invalid_rule"
                    elif not (sanitized[0].isalpha() or sanitized[0] == '_'):
                        sanitized = f"_{sanitized}"

                    if sanitized in name_counts:
                        # Duplicate: append a numeric suffix to keep it unique.
                        name_counts[sanitized] += 1
                        unique = f"{sanitized}_{name_counts[sanitized]}"
                        text = text.replace(original, unique, 1)
                        renamed_total += 1
                        logger.debug(f"Renamed duplicate rule: {original} -> {unique}")
                    else:
                        name_counts[sanitized] = 0
                        if sanitized != original:
                            text = text.replace(original, sanitized, 1)
                            logger.debug(f"Sanitized rule name: {original} -> {sanitized}")
                rewritten.append(text + '\n')

        with open(dst, "w", encoding="utf-8") as fh:
            fh.writelines(rewritten)

    logger.info(f"Processed {renamed_total} duplicate rules")
    return out_dir
def compile_yara_rules(repo_path: Path) -> Optional[yara.Rules]:
    """Compile YARA rules from repository with error handling.

    Runs the preprocessing pass (sanitized, de-duplicated rule names), then
    compiles every .yara file under the processed tree into one ruleset.

    Args:
        repo_path: checkout of the YARA rules repository.

    Returns:
        Compiled yara.Rules, or None if no files were found or compilation
        failed (errors are logged, never raised).
    """
    try:
        processed_dir = preprocess_yara_rules(repo_path)
        yara_files = list(processed_dir.glob("**/*.yara"))
        if not yara_files:
            logger.warning("No YARA files found in repository")
            return None
        logger.info(f"Found {len(yara_files)} YARA files, compiling rules")
        # Key each file by path + index so namespaces are unique even if two
        # files share a name.  (The old per-item try/except wrapped a plain
        # dict assignment that cannot raise, and str() around an f-string was
        # redundant — both removed.)
        filepaths = {f"{path}_{i}": str(path) for i, path in enumerate(yara_files)}
        return yara.compile(filepaths=filepaths)
    except yara.SyntaxError as e:
        logger.error(f"YARA syntax error: {str(e)}")
        return None
    except Exception as e:
        logger.error(f"Error compiling YARA rules: {str(e)}")
        return None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifecycle.

    Startup: download and load the Keras model from the Hugging Face Hub,
    then clone and compile the YARA ruleset into app.state.
    Shutdown (after yield): drop references so resources can be reclaimed.

    Fix: FastAPI's `lifespan=` argument expects an async *context manager*;
    the function was written as a bare async generator, so the
    `@asynccontextmanager` decorator (already imported at the top of the
    file but unused) is applied here.
    """
    # Initialize app state so handlers can safely test for None.
    app.state.model = None
    app.state.yara_rules = None
    try:
        # Load ML model
        logger.info("Downloading model from Hugging Face Hub...")
        download_dir = snapshot_download(MODEL_REPO)
        app.state.model = tf.keras.models.load_model(os.path.join(download_dir, MODEL_FILE))
        logger.info("Model loaded successfully")

        # Clone and compile YARA rules (best-effort: the app still starts
        # without them; /analyse_yara returns 503 in that case).
        yara_repo_path = await clone_yara_repo()
        if yara_repo_path:
            app.state.yara_rules = compile_yara_rules(yara_repo_path)
            if app.state.yara_rules:
                logger.info("YARA rules compiled successfully")
            else:
                logger.warning("No valid YARA rules compiled")
        else:
            logger.warning("YARA rules unavailable")
    except Exception as e:
        # A missing model is fatal: re-raise so startup fails loudly.
        logger.error(f"Initialization error: {str(e)}")
        raise
    yield
    # Cleanup
    app.state.model = None
    app.state.yara_rules = None
| app = FastAPI(lifespan=lifespan) | |
async def greet_json():
    """Return a trivial hello-world payload (health-check style endpoint).

    NOTE(review): no route decorator is visible on this handler — confirm it
    is registered with the app (e.g. @app.get("/")); the decorator may have
    been lost upstream.
    """
    payload = dict(Hello="World!")
    return payload
async def upload_file(file: UploadFile = File(...)):
    """Handle file uploads with async operations and path sanitization.

    Streams the upload to UPLOAD_DIR in chunks (the old code buffered the
    whole file in memory) and logs the actual stored filename (the old log
    message contained a broken placeholder).

    Returns:
        dict with the stored filename and a success message.

    Raises:
        HTTPException(500): on any write failure.
    """
    try:
        # Sanitize filename to prevent path traversal
        filename = Path(file.filename).name
        file_path = os.path.join(UPLOAD_DIR, filename)
        async with aiofiles.open(file_path, "wb") as buffer:
            # Copy in 1 MiB chunks so large samples don't sit wholly in RAM.
            while chunk := await file.read(1024 * 1024):
                await buffer.write(chunk)
        logger.info("File uploaded successfully: %s", filename)
        return {"filename": filename, "message": "File uploaded successfully"}
    except Exception as e:
        logger.error(f"Error uploading file: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error uploading file: {str(e)}"
        )
async def list_files():
    """List all uploaded files"""
    try:
        # Hide temp_* conversion artifacts produced by analyse_bin.
        visible = [name for name in os.listdir(UPLOAD_DIR)
                   if not name.startswith('temp')]
        return {"files": visible}
    except Exception as e:
        logger.error(f"Error listing files: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error listing files: {str(e)}"
        )
async def download_file(file_name: str):
    """Serve files for download with proper security checks"""
    # Strip any directory components to block path traversal.
    safe_name = Path(file_name).name
    target = os.path.join(UPLOAD_DIR, safe_name)
    if not os.path.exists(target):
        logger.warning(f"File not found: {safe_name}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )
    return FileResponse(target, filename=safe_name)
def softmax(x):
    """Numerically stable softmax over all elements of x."""
    shifted = x - np.max(x)     # shift by the max so exp() cannot overflow
    weights = np.exp(shifted)
    return weights / weights.sum()
def predict_malware(img_array: np.ndarray) -> Dict:
    """Make prediction using the preloaded model.

    Args:
        img_array: batched grayscale image tensor as produced by
            process_image — assumed shape (1, H, W, 1); TODO confirm.

    Returns:
        dict with "result" (top class name from MAL_CLASSES) and
        "all_results" (class name -> probability).  The old annotation
        said `-> str`, which contradicted the actual dict return.

    Raises:
        HTTPException(500): if inference fails.
    """
    try:
        prediction = app.state.model.predict(img_array)
        # NOTE(review): softmax is applied to the raw model output — assumes
        # the model emits logits; if its head is already softmax this double-
        # normalizes. Verify against the model definition.
        probabilities = softmax(prediction).flatten().tolist()
        return {
            "result": MAL_CLASSES[np.argmax(prediction)],
            "all_results": dict(zip(MAL_CLASSES, probabilities))
        }
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing prediction"
        )
async def process_image(file_path: str, target_size: tuple = (64, 64)) -> np.ndarray:
    """Process image file for model input.

    Loads the image as grayscale, resizes to target_size, and returns a
    batched tensor (leading axis of size 1) ready for model.predict.

    Fix: image decoding is blocking; the old code ran it directly inside an
    `async def`, stalling the event loop.  It is now offloaded to the
    threadpool via run_in_threadpool (already imported by this module).

    Raises:
        HTTPException(400): if the file cannot be decoded as an image.
    """
    try:
        # Blocking PIL/TF decode — run it off the event loop.
        img = await run_in_threadpool(
            tf.keras.utils.load_img,
            file_path,
            target_size=target_size,
            color_mode="grayscale",
        )
        img_array = tf.keras.utils.img_to_array(img)
        # NOTE(review): returns a tf.Tensor, not an np.ndarray as annotated;
        # model.predict accepts both, so behavior is unchanged.
        return tf.expand_dims(img_array, axis=0)
    except Exception as e:
        logger.error(f"Image processing error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid image file format"
        )
| # @app.get("/analyse/{file_name}") | |
| # async def analyse(file_name: str): | |
| # """Analyze image files""" | |
| # sanitized_name = Path(file_name).name | |
| # file_path = os.path.join(UPLOAD_DIR, sanitized_name) | |
| # if not os.path.exists(file_path): | |
| # raise HTTPException( | |
| # status_code=status.HTTP_404_NOT_FOUND, | |
| # detail="File not found" | |
| # ) | |
| # try: | |
| # img_array = await process_image(file_path) | |
| # result = predict_malware(img_array) | |
| # return result | |
| # except HTTPException as he: | |
| # raise he | |
| # except Exception as e: | |
| # logger.error(f"Analysis error: {str(e)}") | |
| # raise HTTPException( | |
| # status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| # detail=f"Analysis failed: {str(e)}" | |
| # ) | |
def get_image_width(file_path: str) -> int:
    """Determine image width based on file size.

    Uses the conventional Malimg bucketing: bigger binaries get wider
    images so the rendered bitmap stays roughly square.
    """
    size_kb = os.path.getsize(file_path) / 1024
    # (exclusive upper bound in KB, width) — scanned in ascending order.
    buckets = (
        (10, 32),
        (30, 64),
        (60, 128),
        (100, 256),
        (200, 384),
        (500, 512),
        (1000, 768),
    )
    for upper, width in buckets:
        if size_kb < upper:
            return width
    return 1024  # anything >= 1000 KB
def convert_binary_to_image(binary_path: str, output_path: str, width: int):
    """Convert binary file to grayscale image with error handling.

    Reads the raw bytes, truncates to a whole number of rows of `width`
    pixels, and writes the result as a grayscale image.

    Fix: a file shorter than one row produced height 0 and an opaque cv2
    failure; that case is now rejected explicitly (still surfaces to the
    caller as the same HTTPException 400).

    Raises:
        HTTPException(400): if the file is empty/too small or conversion fails.
    """
    try:
        with open(binary_path, "rb") as f:
            binary_data = f.read()
        grayscale = np.frombuffer(binary_data, dtype=np.uint8)
        height = len(grayscale) // width
        if height == 0:
            # Fewer bytes than one row — reshape would yield a 0-row image.
            raise ValueError(f"File too small to form an image of width {width}")
        # Drop the trailing partial row so the buffer reshapes cleanly.
        grayscale = grayscale[:height * width].reshape((height, width))
        cv2.imwrite(output_path, grayscale)
        logger.debug(f"Image converted: {output_path}")
    except Exception as e:
        logger.error(f"Binary conversion error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid binary file format"
        )
async def image(file_name: str):
    """Serve a generated image from the uploads directory as PNG.

    Fix: the raw file_name was joined into the path without sanitization,
    unlike every other handler in this module — a path-traversal hole.
    Now strips directory components, consistent with download_file.

    NOTE(review): no route decorator visible — confirm registration
    (e.g. @app.get("/image/{file_name}")) wasn't lost upstream.
    """
    safe_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, safe_name)
    if os.path.exists(file_path):
        return FileResponse(file_path, media_type="image/png")
    return JSONResponse(content={"error": "Image not found"})
async def analyse_bin(file_name: str):
    """Analyze binary files by converting to images.

    Pipeline: size-based width selection -> binary-to-grayscale image ->
    CNN prediction.  The temp image is deliberately kept on disk so the
    /image endpoint can serve it via the returned "image_location".

    Fix: the old code exposed the temp path via `str(temp_image)[8:]` — a
    magic slice that hard-codes len("uploads/") and silently breaks if
    UPLOAD_DIR ever changes.  os.path.basename yields the same value
    robustly.

    Raises:
        HTTPException(404): if the sample is missing.
        HTTPException(400/500): propagated from conversion/prediction.
    """
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)
    if not os.path.exists(file_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )
    temp_image = os.path.join(UPLOAD_DIR, f"temp_{uuid.uuid4()}.png")
    try:
        width = get_image_width(file_path)
        convert_binary_to_image(file_path, temp_image, width)
        img_array = await process_image(temp_image)
        result = predict_malware(img_array)
        return {
            **result,
            "image_location": os.path.basename(temp_image)
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Binary analysis error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Binary analysis failed: {str(e)}"
        )
async def analyse_yara(file_name: str):
    """Analyze file using YARA rules from the GitHub repository"""
    safe_name = Path(file_name).name
    target = os.path.join(UPLOAD_DIR, safe_name)

    # Guard clauses: sample must exist and the ruleset must be loaded.
    if not os.path.exists(target):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )
    if not app.state.yara_rules:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="YARA rules not available"
        )

    try:
        # Rule matching is CPU-bound; keep it off the event loop.
        matches = await run_in_threadpool(app.state.yara_rules.match, target)
        if not matches:
            return {"result": "Does not match"}
        return {
            "result": "Found",
            "matches": [
                {
                    "rule": m.rule,
                    "namespace": m.namespace,
                    "tags": m.tags,
                    "meta": m.meta,
                }
                for m in matches
            ],
        }
    except Exception as e:
        logger.error(f"YARA analysis failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="YARA analysis failed"
        )