# mal / app.py
# GranularFireplace's picture
# img fix
# 572f65a verified
from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.responses import JSONResponse, FileResponse
from contextlib import asynccontextmanager
from pathlib import Path
import tensorflow as tf
import numpy as np
import os
import shutil
import cv2
import logging
import uuid
import yara
import asyncio
import re
from huggingface_hub import snapshot_download
from typing import Optional, List, Dict
import aiofiles
from fastapi.concurrency import run_in_threadpool
# Logging: INFO level on the root logger, plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Malware family labels. predict_malware() indexes this list with the argmax
# of the model output, so the order must match the model's output order.
MAL_CLASSES = [
    'Adialer.C',
    'Agent.FYI',
    'Allaple.A',
    'Allaple.L',
    'Alueron.gen!J',
    'Autorun.K',
    'C2LOP.P',
    'C2LOP.gen!g',
    'Dialplatform.B',
    'Dontovo.A',
    'Fakerean',
    'Instantaccess',
    'Lolyda.AA1',
    'Lolyda.AA2',
    'Lolyda.AA3',
    'Lolyda.AT',
    'Malex.gen!J',
    'Obfuscator.AD',
    'Rbot!gen',
    'Skintrim.N',
    'Swizzor.gen!E',
    'Swizzor.gen!I',
    'VB.AT',
    'Wintrim.BX',
    'Yuner.A',
]

# Directory that receives uploaded files; created at import time if missing.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Git source of the YARA rules and the local directory it is cloned into.
YARA_REPO_URL = "https://github.com/t-tani/defender2yara.git"
YARA_REPO_BRANCH = "yara-rules"
YARA_REPO_DIR = "defenderyara"

# Model location on the Hugging Face Hub, overridable through the environment.
MODEL_REPO = os.environ.get("MODEL_REPO", "GranularFireplace/malware")
MODEL_FILE = os.environ.get("MODEL_FILE", "model_v2_with_weight.keras")
async def clone_yara_repo():
    """Fetch a fresh copy of the YARA rules branch with ``git clone``.

    Any checkout already present in ``YARA_REPO_DIR`` is deleted first, then
    the single branch ``YARA_REPO_BRANCH`` of ``YARA_REPO_URL`` is cloned into
    that directory through an asyncio subprocess.

    Returns:
        The ``Path`` of the cloned repository, or ``None`` when git exits with
        a non-zero status or any exception is raised along the way.
    """
    try:
        target_dir = Path(YARA_REPO_DIR)

        # Start from a clean slate: drop any previous checkout.
        if target_dir.exists():
            logger.info("Removing existing YARA rules repository")
            shutil.rmtree(target_dir)

        logger.info(f"Cloning YARA rules from {YARA_REPO_URL}")
        git_command = [
            'git', 'clone',
            '-b', YARA_REPO_BRANCH,
            '--single-branch',
            YARA_REPO_URL,
            str(target_dir),
        ]
        process = await asyncio.create_subprocess_exec(
            *git_command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # Only stderr is reported; stdout is drained and discarded.
        _, error_output = await process.communicate()

        if process.returncode == 0:
            logger.info("YARA rules repository cloned successfully")
            return target_dir

        logger.error(f"Failed to clone YARA repo: {error_output.decode()}")
        return None
    except Exception as exc:
        logger.error(f"Error cloning YARA repository: {exc}")
        return None
def preprocess_yara_rules(repo_path: Path) -> Path:
    """Copy the repository's YARA files into a clean tree with valid, unique rule names.

    Every ``*.yara`` file under ``repo_path`` (except files named
    ``misc.yara``) is rewritten line by line into ``processed_yara_rules/``,
    keeping its relative path. On each line that starts a rule declaration
    (``rule <name>``) the name is:

    * sanitized: characters outside ``[a-zA-Z0-9_]`` become ``_`` and a
      leading digit gets a ``_`` prefix;
    * de-duplicated across all files: a repeated name gets a numeric suffix
      (``name_1``, ``name_2``, ...), skipping any suffix that would collide
      with a name that has already been emitted.

    Args:
        repo_path: Root directory of the cloned rules repository.

    Returns:
        Path of the directory holding the rewritten rule files. Any previous
        contents of that directory are removed first.
    """
    processed_dir = Path("processed_yara_rules")
    if processed_dir.exists():
        shutil.rmtree(processed_dir)
    processed_dir.mkdir()

    # Group 1 is the indentation plus the `rule` keyword, group 2 the name.
    # Capturing the name separately lets it be replaced by position, so a
    # name that is a substring of "rule" cannot corrupt the keyword.
    rule_pattern = re.compile(r"(\s*rule\s+)([^\s{]+)")
    invalid_chars = re.compile(r"[^a-zA-Z0-9_]")

    used_names = set()   # every rule name written so far, across all files
    next_suffix = {}     # last numeric suffix handed out per base name
    rule_counter = 0     # number of duplicates that had to be renamed

    for yara_file in repo_path.glob("**/*.yara"):
        if yara_file.name == "misc.yara":
            logger.info(f"Skipping {yara_file} as it does not belong to any malware family")
            continue

        # Preserve directory structure
        try:
            relative_path = yara_file.relative_to(repo_path)
        except ValueError:
            continue  # Skip files not in repo path
        processed_file = processed_dir / relative_path
        processed_file.parent.mkdir(parents=True, exist_ok=True)

        new_content = []
        with open(yara_file, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                line = line.rstrip("\n")
                declaration = rule_pattern.match(line)
                if declaration:
                    original_name = declaration.group(2)

                    # YARA identifiers: alphanumerics and underscore only,
                    # and the first character must not be a digit.
                    clean_name = invalid_chars.sub("_", original_name)
                    if clean_name[0].isdigit():
                        clean_name = f"_{clean_name}"

                    new_name = clean_name
                    if clean_name in used_names:
                        # Find the next free "<name>_<n>"; the loop skips
                        # suffixes already taken by a rule with that name.
                        suffix = next_suffix.get(clean_name, 0)
                        while True:
                            suffix += 1
                            new_name = f"{clean_name}_{suffix}"
                            if new_name not in used_names:
                                break
                        next_suffix[clean_name] = suffix
                        rule_counter += 1
                        logger.debug(f"Renamed duplicate rule: {original_name} -> {new_name}")
                    elif clean_name != original_name:
                        logger.debug(f"Sanitized rule name: {original_name} -> {clean_name}")

                    used_names.add(new_name)
                    if new_name != original_name:
                        start, end = declaration.span(2)
                        line = f"{line[:start]}{new_name}{line[end:]}"
                new_content.append(line + "\n")

        with open(processed_file, "w", encoding="utf-8") as f:
            f.writelines(new_content)

    logger.info(f"Processed {rule_counter} duplicate rules")
    return processed_dir
def compile_yara_rules(repo_path: Path) -> Optional[yara.Rules]:
    """Preprocess the cloned rule files and compile them into one ruleset.

    Each preprocessed file is registered under its own namespace, built from
    the file path and its position in the file list.

    Args:
        repo_path: Root directory of the cloned rules repository.

    Returns:
        The compiled ``yara.Rules`` object, or ``None`` when no rule files are
        found or compilation fails for any reason.
    """
    try:
        rules_dir = preprocess_yara_rules(repo_path)
        yara_files = list(rules_dir.glob("**/*.yara"))

        if not yara_files:
            logger.warning("No YARA files found in repository")
            return None

        logger.info(f"Found {len(yara_files)} YARA files, compiling rules")

        # namespace -> file path; the index suffix keeps namespaces distinct.
        namespaced_files = {
            f"{rule_file}_{index}": str(rule_file)
            for index, rule_file in enumerate(yara_files)
        }
        return yara.compile(filepaths=namespaced_files)
    except yara.SyntaxError as exc:
        logger.error(f"YARA syntax error: {exc}")
        return None
    except Exception as exc:
        logger.error(f"Error compiling YARA rules: {exc}")
        return None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and YARA rules on startup and release them on shutdown.

    Startup downloads ``MODEL_REPO`` from the Hugging Face Hub, loads
    ``MODEL_FILE`` with Keras, then clones and compiles the YARA rules.
    Unavailable YARA rules only produce a warning; any exception during
    startup is logged and re-raised.
    """
    state = app.state
    state.model = None
    state.yara_rules = None

    try:
        # --- ML model ---
        logger.info("Downloading model from Hugging Face Hub...")
        model_path = os.path.join(snapshot_download(MODEL_REPO), MODEL_FILE)
        state.model = tf.keras.models.load_model(model_path)
        logger.info("Model loaded successfully")

        # --- YARA rules ---
        yara_repo_path = await clone_yara_repo()
        if not yara_repo_path:
            logger.warning("YARA rules unavailable")
        else:
            state.yara_rules = compile_yara_rules(yara_repo_path)
            if state.yara_rules:
                logger.info("YARA rules compiled successfully")
            else:
                logger.warning("No valid YARA rules compiled")
    except Exception as exc:
        logger.error(f"Initialization error: {exc}")
        raise

    yield

    # Shutdown: drop the references held on the application state.
    state.model = None
    state.yara_rules = None
# Application instance; startup and shutdown work is handled by `lifespan`.
app = FastAPI(lifespan=lifespan)
@app.get("/")
async def greet_json():
    """Return a fixed greeting payload from the root route."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/upload", status_code=status.HTTP_201_CREATED)
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded file in ``UPLOAD_DIR`` under its base name.

    The client-supplied name is reduced to its final path component to prevent
    path traversal, and the body is copied to disk in chunks so the whole
    upload is never held in memory at once.

    Args:
        file: The multipart upload.

    Returns:
        A dict with the stored ``filename`` and a confirmation ``message``.

    Raises:
        HTTPException: 400 when no usable filename was supplied; 500 when
            writing the file fails.
    """
    # Path(...).name drops directory parts, but a missing name, "" and "."
    # all yield "" and ".." is returned unchanged, so reject those outright.
    filename = Path(file.filename or "").name
    if filename in ("", ".", ".."):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid filename"
        )

    file_path = os.path.join(UPLOAD_DIR, filename)
    chunk_size = 1024 * 1024  # copy 1 MiB at a time
    try:
        async with aiofiles.open(file_path, "wb") as buffer:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                await buffer.write(chunk)
    except Exception as e:
        logger.error(f"Error uploading file: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error uploading file: {str(e)}"
        ) from e

    logger.info(f"File uploaded successfully: {filename}")
    return {"filename": filename, "message": "File uploaded successfully"}
@app.get("/files")
async def list_files():
    """Return the names in ``UPLOAD_DIR``, hiding entries that start with ``temp``.

    Raises:
        HTTPException: 500 when the directory cannot be listed.
    """
    try:
        visible_files = []
        for entry in os.listdir(UPLOAD_DIR):
            # Names starting with "temp" are left out of the listing.
            if entry.startswith('temp'):
                continue
            visible_files.append(entry)
        return {"files": visible_files}
    except Exception as exc:
        logger.error(f"Error listing files: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error listing files: {exc}"
        )
@app.get("/download/{file_name}")
async def download_file(file_name: str):
    """Serve a previously uploaded file as a download.

    Args:
        file_name: Requested name; only its final path component is used.

    Returns:
        A ``FileResponse`` for the file in ``UPLOAD_DIR``.

    Raises:
        HTTPException: 404 when no regular file with that name exists.
    """
    # Sanitize input filename
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    # isfile rather than exists: "." and ".." come out of the sanitisation as
    # "" and "..", which resolve to directories that cannot be served.
    if not os.path.isfile(file_path):
        logger.warning(f"File not found: {sanitized_name}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )
    return FileResponse(file_path, filename=sanitized_name)
def softmax(x, axis=None):
    """Compute a numerically stable softmax.

    Args:
        x: Array-like of scores.
        axis: Axis along which to normalise. The default ``None`` normalises
            over the whole array, so all entries together sum to 1; pass e.g.
            ``axis=-1`` to normalise each row of a batch independently.

    Returns:
        An array with the same shape as ``x``.
    """
    x = np.asarray(x)
    # Subtract the max for numerical stability; keepdims lets the reduced
    # values broadcast back against x for any choice of axis.
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def predict_malware(img_array: np.ndarray) -> Dict[str, object]:
    """Classify a preprocessed image with the model stored on ``app.state``.

    Args:
        img_array: Model input batch.
            NOTE(review): the code assumes a single-image batch, so that the
            flattened output lines up with ``MAL_CLASSES`` - confirm.

    Returns:
        A dict with ``result`` (the label at the position of the largest raw
        model output) and ``all_results`` (label -> softmax probability).

    Raises:
        HTTPException: 500 when prediction or post-processing fails.
    """
    try:
        prediction = app.state.model.predict(img_array)
        probabilities = softmax(prediction).flatten().tolist()
        # argmax works on the flattened output; cast to a plain int for list indexing.
        best_index = int(np.argmax(prediction))
        return {
            "result": MAL_CLASSES[best_index],
            "all_results": dict(zip(MAL_CLASSES, probabilities))
        }
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing prediction"
        ) from e
async def process_image(file_path: str, target_size: tuple = (64, 64)) -> np.ndarray:
    """Load an image file as a single-item grayscale batch for the model.

    Args:
        file_path: Path of the image to load.
        target_size: Size passed through to Keras ``load_img`` for resizing.

    Returns:
        A NumPy array: the ``img_to_array`` output with a leading batch axis
        added, matching the declared return type.

    Raises:
        HTTPException: 400 when the file cannot be loaded or converted.
    """
    try:
        img = tf.keras.utils.load_img(file_path, target_size=target_size, color_mode="grayscale")
        img_array = tf.keras.utils.img_to_array(img)
        # np.expand_dims keeps the result an ndarray; tf.expand_dims would
        # return a tf.Tensor despite the annotation.
        return np.expand_dims(img_array, axis=0)
    except Exception as e:
        logger.error(f"Image processing error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid image file format"
        ) from e
# @app.get("/analyse/{file_name}")
# async def analyse(file_name: str):
# """Analyze image files"""
# sanitized_name = Path(file_name).name
# file_path = os.path.join(UPLOAD_DIR, sanitized_name)
# if not os.path.exists(file_path):
# raise HTTPException(
# status_code=status.HTTP_404_NOT_FOUND,
# detail="File not found"
# )
# try:
# img_array = await process_image(file_path)
# result = predict_malware(img_array)
# return result
# except HTTPException as he:
# raise he
# except Exception as e:
# logger.error(f"Analysis error: {str(e)}")
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
# detail=f"Analysis failed: {str(e)}"
# )
def get_image_width(file_path: str) -> int:
    """Pick the image width to use for a binary, based on its size on disk.

    The size in KiB is looked up in a fixed table of half-open ranges: below
    10 -> 32, below 30 -> 64, below 60 -> 128, below 100 -> 256,
    below 200 -> 384, below 500 -> 512, below 1000 -> 768, otherwise 1024.

    Args:
        file_path: Path of the file to measure.

    Returns:
        The width in pixels.

    Raises:
        ValueError: If the size matches no range.
    """
    size_kb = os.path.getsize(file_path) / 1024

    # (exclusive upper bound in KiB, width). The ranges are contiguous and
    # start at 0, so the first bound above the size identifies the range.
    width_by_upper_bound = (
        (10, 32),
        (30, 64),
        (60, 128),
        (100, 256),
        (200, 384),
        (500, 512),
        (1000, 768),
        (float('inf'), 1024),
    )
    chosen_width = next(
        (width for upper_kb, width in width_by_upper_bound if size_kb < upper_kb),
        None,
    )
    if chosen_width is None:
        raise ValueError("File size out of expected range")
    return chosen_width
def convert_binary_to_image(binary_path: str, output_path: str, width: int):
    """Render a file's raw bytes as a grayscale image, one byte per pixel.

    The bytes are laid out row by row with ``width`` pixels per row; trailing
    bytes that do not fill a complete row are dropped.

    Args:
        binary_path: File whose bytes are rendered.
        output_path: Destination image path, written with ``cv2.imwrite``.
        width: Number of pixels per image row.

    Raises:
        HTTPException: 400 when the file cannot be read, holds fewer bytes
            than one full row, or the image cannot be written.
    """
    try:
        with open(binary_path, "rb") as f:
            binary_data = f.read()

        grayscale = np.frombuffer(binary_data, dtype=np.uint8)
        height = len(grayscale) // width
        if height == 0:
            # Fewer bytes than one row would produce an empty image.
            raise ValueError(
                f"{len(grayscale)} bytes cannot fill a single row of width {width}"
            )
        grayscale = grayscale[:height * width].reshape((height, width))

        # imwrite signals some failures through its return value, not an exception.
        if not cv2.imwrite(output_path, grayscale):
            raise OSError(f"Could not write image to {output_path}")
        logger.debug(f"Image converted: {output_path}")
    except Exception as e:
        logger.error(f"Binary conversion error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid binary file format"
        ) from e
@app.get("/image/{file_name}")
async def image(file_name: str):
    """Serve a PNG stored in ``UPLOAD_DIR``.

    Args:
        file_name: Requested name; only its final path component is used, as
            in the other file endpoints.

    Returns:
        A ``FileResponse`` with media type ``image/png``, or a 404 JSON
        response ``{"error": "Image not found"}`` when no such file exists.
    """
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    # isfile rather than exists, so directories (e.g. from "..") are not served.
    if os.path.isfile(file_path):
        return FileResponse(file_path, media_type="image/png")
    return JSONResponse(
        status_code=status.HTTP_404_NOT_FOUND,
        content={"error": "Image not found"}
    )
@app.get("/analysebin/{file_name}")
async def analyse_bin(file_name: str):
    """Classify an uploaded binary by rendering it as an image first.

    The file is converted to a grayscale PNG named ``temp_<uuid>.png`` inside
    ``UPLOAD_DIR``, which is then fed to the model. On success the PNG is
    kept and its file name is returned as ``image_location``; on failure it
    is removed again.

    Args:
        file_name: Name of the uploaded file; only its final path component
            is used.

    Returns:
        The prediction dict from ``predict_malware`` plus ``image_location``.

    Raises:
        HTTPException: 404 when the file does not exist; errors raised by the
            helpers are passed through; anything else becomes a 500.
    """
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    # isfile rather than exists, so directories (e.g. from "..") yield a 404.
    if not os.path.isfile(file_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    temp_image = os.path.join(UPLOAD_DIR, f"temp_{uuid.uuid4()}.png")
    succeeded = False
    try:
        width = get_image_width(file_path)
        # Conversion and inference are blocking calls; run them in the thread
        # pool so the event loop stays responsive, as analyse_yara does.
        await run_in_threadpool(convert_binary_to_image, file_path, temp_image, width)
        img_array = await process_image(temp_image)
        result = await run_in_threadpool(predict_malware, img_array)
        succeeded = True
        return {
            **result,
            # File name only, without relying on the length of UPLOAD_DIR.
            "image_location": os.path.basename(temp_image)
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Binary analysis error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Binary analysis failed: {str(e)}"
        ) from e
    finally:
        if not succeeded:
            # Best-effort cleanup of an image nobody will be told about.
            try:
                os.remove(temp_image)
            except OSError:
                pass
@app.get("/analyse/yara/{file_name}")
async def analyse_yara(file_name: str):
    """Scan an uploaded file with the YARA rules compiled at startup.

    Args:
        file_name: Name of the uploaded file; only its final path component
            is used.

    Returns:
        ``{"result": "Found", "matches": [...]}`` with the rule, namespace,
        tags and meta of every match, or ``{"result": "Does not match"}``.

    Raises:
        HTTPException: 404 when the file does not exist, 503 when no rules
            are loaded, 500 when the scan itself fails.
    """
    sanitized_name = Path(file_name).name
    file_path = os.path.join(UPLOAD_DIR, sanitized_name)

    # isfile rather than exists, so directories (e.g. from "..") yield a 404
    # instead of failing inside the scanner.
    if not os.path.isfile(file_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    if not app.state.yara_rules:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="YARA rules not available"
        )

    try:
        # The scan is a blocking call, so it runs in the thread pool.
        matches = await run_in_threadpool(
            app.state.yara_rules.match,
            file_path
        )
        if not matches:
            return {"result": "Does not match"}

        return {
            "result": "Found",
            "matches": [
                {
                    "rule": match.rule,
                    "namespace": match.namespace,
                    "tags": match.tags,
                    "meta": match.meta,
                }
                for match in matches
            ]
        }
    except Exception as e:
        logger.error(f"YARA analysis failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="YARA analysis failed"
        ) from e