# Commit 0e805d4 (Xernive): "Fix Hunyuan3D error handling + enhanced logging"
"""Security utilities for input sanitization and rate limiting."""
import math
import os
import time
from collections import defaultdict
from typing import Dict, List

from core.config import (
    MAX_PROMPT_LENGTH,
    MAX_REQUESTS_PER_HOUR,
    REQUEST_WINDOW_SECONDS,
)
class SecurityManager:
    """Manages security features like rate limiting and input sanitization.

    Rate-limit state lives in memory on the instance (``user_requests``),
    keyed by ``user_id``; separate instances do not share it.
    """

    # Characters rejected by ``sanitize_prompt``. NOTE: '\n', '\r' and '\0'
    # are non-printable, so ``sanitize_prompt`` already removes them before
    # this list is consulted; they are kept so the public list is unchanged.
    FORBIDDEN_CHARS = ['<', '>', '|', '&', ';', '`', '$', '(', ')', '\n', '\r', '\0', '\\']

    def __init__(self):
        # user_id -> time.time() timestamps of accepted requests, oldest first.
        self.user_requests: Dict[str, List[float]] = defaultdict(list)

    def sanitize_prompt(self, prompt: str) -> str:
        """Return a cleaned copy of *prompt*, or raise if it is unacceptable.

        Non-printable characters (control characters, newlines, tabs, format
        characters such as zero-width spaces) are removed and surrounding
        whitespace is stripped.

        Args:
            prompt: Raw user text. ``None`` is treated as empty.

        Returns:
            The cleaned, stripped prompt.

        Raises:
            ValueError: if the prompt is empty or blank (also after
                non-printable characters are removed), is longer than
                ``MAX_PROMPT_LENGTH`` characters once cleaned, or contains
                any character from ``FORBIDDEN_CHARS``.
        """
        # Clean first, then test for emptiness: otherwise an input made only
        # of non-printable characters (e.g. "\x00\x01") would be accepted and
        # returned as "".
        cleaned = ''.join(c for c in (prompt or '') if c.isprintable()).strip()
        if not cleaned:
            raise ValueError("Prompt cannot be empty")

        # Measure the text that is actually returned to the caller.
        if len(cleaned) > MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt too long (max {MAX_PROMPT_LENGTH} characters, got {len(cleaned)})"
            )

        for char in self.FORBIDDEN_CHARS:
            if char in cleaned:
                raise ValueError(f"Forbidden character in prompt: {char}")
        return cleaned

    def check_rate_limit(self, user_id: str = "default") -> None:
        """Record a request for *user_id*, or raise if the user is over the limit.

        Uses a sliding window of ``REQUEST_WINDOW_SECONDS``: timestamps older
        than the window are discarded first. A rejected request is not recorded.

        Args:
            user_id: Key identifying the caller.

        Raises:
            ValueError: if ``MAX_REQUESTS_PER_HOUR`` requests are already in
                the window. The message gives the wait rounded up to whole
                minutes, never less than 1.
        """
        now = time.time()
        recent = [t for t in self.user_requests[user_id] if now - t < REQUEST_WINDOW_SECONDS]
        self.user_requests[user_id] = recent

        if len(recent) >= MAX_REQUESTS_PER_HOUR:
            # The oldest request (appended first) is the next to leave the
            # window. With a limit of 0 there is none, so wait a full window
            # instead of indexing an empty list.
            oldest = recent[0] if recent else now
            remaining = REQUEST_WINDOW_SECONDS - (now - oldest)
            raise ValueError(
                f"Rate limit exceeded. Try again in {self._wait_minutes(remaining)} minutes."
            )

        recent.append(now)
        print(f"[Security] Rate limit: {len(recent)}/{MAX_REQUESTS_PER_HOUR}")

    @staticmethod
    def _wait_minutes(remaining_seconds: float) -> int:
        """Convert a wait in seconds to whole minutes, rounded up, minimum 1."""
        # Truncation would report "0 minutes" for any wait under 60 seconds.
        return max(1, math.ceil(remaining_seconds / 60))

    def validate_file_size(self, file_path: str, max_size_mb: float = 100.0) -> float:
        """Return the size of *file_path* in MB (10**6 bytes) after checking a limit.

        Args:
            file_path: Path to the file to check.
            max_size_mb: Largest allowed size in MB (10**6 bytes).

        Returns:
            The file size in MB.

        Raises:
            ValueError: if the file cannot be stat'ed (missing or otherwise
                inaccessible), or if it is larger than *max_size_mb*.
        """
        # EAFP: one getsize call avoids the race where the file disappears
        # between an exists() check and the size lookup.
        try:
            size_bytes = os.path.getsize(file_path)
        except OSError as err:
            raise ValueError(f"File not found: {file_path}") from err

        size_mb = size_bytes / 1e6
        if size_mb > max_size_mb:
            raise ValueError(
                f"File too large: {size_mb:.2f}MB (max {max_size_mb:.2f}MB)"
            )
        return size_mb