"""
Safety wrapper for Helion-V2.0-Thinking

Implements content filtering, rate limiting, and safety checks.
"""

import hashlib
import json
import re
from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor


class SafetyViolation(Exception):
    """Exception raised when safety policies are violated"""
    pass


class ContentFilter:
    """Content filtering for inputs and outputs"""

    HARMFUL_PATTERNS = [
        r'(?i)how\s+to\s+(hack|crack|break\s+into)',
        r'(?i)make\s+(explosive|bomb|weapon)',
        r'(?i)(kill|murder|harm)\s+(myself|someone|people)',
        r'(?i)credit\s+card\s+number',
        r'(?i)social\s+security\s+number',
        r'(?i)(steal|fraud|scam)\s+',
        r'(?i)illegal\s+(drugs|substances)',
        r'(?i)child\s+(abuse|exploitation)',
    ]

    PII_PATTERNS = {
        'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        'phone': r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
        'ssn': r'\b\d{3}-\d{2}-\d{4}\b',
        'credit_card': r'\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b',
    }

    TOXIC_KEYWORDS = [
        'hate', 'violence', 'threat', 'abuse', 'harass',
        'discriminate', 'racist', 'sexist', 'offensive'
    ]

    @staticmethod
    def check_harmful_content(text: str) -> tuple[bool, Optional[str]]:
        """
        Check if text contains harmful content

        Returns:
            (is_harmful, reason)
        """
        for pattern in ContentFilter.HARMFUL_PATTERNS:
            if re.search(pattern, text):
                return True, f"Matched harmful pattern: {pattern}"

        return False, None

    @staticmethod
    def check_pii(text: str) -> tuple[bool, List[str]]:
        """
        Check for personally identifiable information

        Returns:
            (contains_pii, pii_types)
        """
        found_pii = []
        for pii_type, pattern in ContentFilter.PII_PATTERNS.items():
            if re.search(pattern, text):
                found_pii.append(pii_type)

        return len(found_pii) > 0, found_pii

    @staticmethod
    def check_toxicity(text: str) -> tuple[float, List[str]]:
        """
        Simple toxicity check based on keywords

        Returns:
            (toxicity_score, matched_keywords)
        """
        text_lower = text.lower()
        matched = [kw for kw in ContentFilter.TOXIC_KEYWORDS if kw in text_lower]
        score = len(matched) / len(ContentFilter.TOXIC_KEYWORDS)

        return score, matched

    @staticmethod
    def redact_pii(text: str) -> str:
        """Redact PII from text"""
        redacted = text

        for pii_type, pattern in ContentFilter.PII_PATTERNS.items():
            redacted = re.sub(pattern, f'[REDACTED_{pii_type.upper()}]', redacted)

        return redacted

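# Illustrative ContentFilter usage (a sketch with hypothetical inputs, based on the
# patterns defined above):
#
#     ContentFilter.check_pii("reach me at jane@example.com")
#     # -> (True, ['email'])
#     ContentFilter.redact_pii("call 555-123-4567")
#     # -> 'call [REDACTED_PHONE]'

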
class RateLimiter:
    """Rate limiting for API usage"""

    def __init__(
        self,
        requests_per_minute: int = 60,
        tokens_per_minute: int = 90000,
        concurrent_limit: int = 10
    ):
        self.requests_per_minute = requests_per_minute
        self.tokens_per_minute = tokens_per_minute
        self.concurrent_limit = concurrent_limit

        self.request_times = defaultdict(deque)
        self.token_counts = defaultdict(deque)
        self.active_requests = defaultdict(int)

    def check_rate_limit(self, user_id: str, estimated_tokens: int = 0) -> tuple[bool, Optional[str]]:
        """
        Check if a request is within rate limits

        Returns:
            (allowed, reason_if_denied)
        """
        now = datetime.now()
        minute_ago = now - timedelta(minutes=1)

        # Drop entries that have aged out of the one-minute sliding window
        while self.request_times[user_id] and self.request_times[user_id][0] < minute_ago:
            self.request_times[user_id].popleft()

        while self.token_counts[user_id] and self.token_counts[user_id][0][0] < minute_ago:
            self.token_counts[user_id].popleft()

        # Request count limit
        if len(self.request_times[user_id]) >= self.requests_per_minute:
            return False, f"Rate limit exceeded: {self.requests_per_minute} requests per minute"

        # Token budget limit
        total_tokens = sum(t[1] for t in self.token_counts[user_id])
        if total_tokens + estimated_tokens > self.tokens_per_minute:
            return False, f"Token limit exceeded: {self.tokens_per_minute} tokens per minute"

        # Concurrency limit
        if self.active_requests[user_id] >= self.concurrent_limit:
            return False, f"Concurrent request limit exceeded: {self.concurrent_limit}"

        return True, None

    def record_request(self, user_id: str, tokens: int = 0):
        """Record a request"""
        now = datetime.now()
        self.request_times[user_id].append(now)
        self.token_counts[user_id].append((now, tokens))
        self.active_requests[user_id] += 1

    def release_request(self, user_id: str):
        """Release an active request slot"""
        if self.active_requests[user_id] > 0:
            self.active_requests[user_id] -= 1

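# Intended RateLimiter call pattern (illustrative sketch; this mirrors how
# SafeHelionWrapper.generate uses it below):
#
#     allowed, reason = limiter.check_rate_limit(user_id, estimated_tokens=512)
#     if allowed:
#         limiter.record_request(user_id, tokens=512)
#         try:
#             ...  # serve the request
#         finally:
#             limiter.release_request(user_id)

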
class SafeHelionWrapper:
    """Safety wrapper for Helion-V2.0-Thinking"""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V2.0-Thinking",
        safety_config_path: Optional[str] = None,
        enable_safety: bool = True,
        enable_rate_limiting: bool = True,
        device: str = "auto"
    ):
        """
        Initialize the safe wrapper

        Args:
            model_name: Model name or path
            safety_config_path: Path to safety_config.json
            enable_safety: Enable safety checks
            enable_rate_limiting: Enable rate limiting
            device: Device for the model
        """
        print(f"Loading model with safety wrapper: {model_name}")

        if safety_config_path:
            with open(safety_config_path, 'r') as f:
                self.safety_config = json.load(f)
        else:
            self.safety_config = self._default_safety_config()

        self.enable_safety = enable_safety
        self.enable_rate_limiting = enable_rate_limiting

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model.eval()

        rate_limits = self.safety_config['safety_settings']['rate_limiting']
        self.content_filter = ContentFilter()
        self.rate_limiter = RateLimiter(
            requests_per_minute=rate_limits['requests_per_minute'],
            tokens_per_minute=rate_limits['tokens_per_minute'],
            concurrent_limit=rate_limits['concurrent_requests']
        )

        self.violation_log = []

        print("Safety wrapper initialized successfully")

    def _default_safety_config(self) -> Dict[str, Any]:
        """Default safety configuration"""
        return {
            "safety_settings": {
                "rate_limiting": {
                    "requests_per_minute": 60,
                    "tokens_per_minute": 90000,
                    "concurrent_requests": 10
                },
                "content_filtering": {
                    "profanity_filter": {"enabled": True},
                    "pii_detection": {"enabled": True},
                    "toxicity_detection": {"enabled": True, "threshold": 0.7}
                }
            }
        }

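    # Note (assumption): an external safety_config.json supplied via
    # safety_config_path should mirror the nesting above, since __init__ and the
    # validators index directly into
    # safety_config['safety_settings']['rate_limiting' / 'content_filtering'].
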
    def _validate_input(
        self,
        prompt: str,
        images: Optional[List[Image.Image]] = None,
        user_id: str = "default"
    ):
        """Validate input against safety policies"""
        if not self.enable_safety:
            return

        # Harmful-content check
        is_harmful, reason = self.content_filter.check_harmful_content(prompt)
        if is_harmful:
            self._log_violation(user_id, "harmful_content", reason)
            raise SafetyViolation(f"Input rejected: {reason}")

        # PII check (warn and log, do not block)
        if self.safety_config['safety_settings']['content_filtering']['pii_detection']['enabled']:
            has_pii, pii_types = self.content_filter.check_pii(prompt)
            if has_pii:
                self._log_violation(user_id, "pii_detected", f"Types: {pii_types}")
                print(f"Warning: PII detected in input: {pii_types}")

        # Toxicity check
        if self.safety_config['safety_settings']['content_filtering']['toxicity_detection']['enabled']:
            toxicity_score, keywords = self.content_filter.check_toxicity(prompt)
            threshold = self.safety_config['safety_settings']['content_filtering']['toxicity_detection']['threshold']

            if toxicity_score > threshold:
                self._log_violation(user_id, "high_toxicity", f"Score: {toxicity_score}")
                raise SafetyViolation(f"Input rejected: High toxicity score ({toxicity_score:.2f})")

        # Image-count check
        if images:
            max_images = self.safety_config.get('guardrails', {}).get('input_validation', {}).get('max_images_per_request', 10)
            if len(images) > max_images:
                raise SafetyViolation(f"Too many images: {len(images)} (max: {max_images})")

    def _validate_output(self, output: str, user_id: str = "default"):
        """Validate output against safety policies"""
        if not self.enable_safety:
            return output

        is_harmful, reason = self.content_filter.check_harmful_content(output)
        if is_harmful:
            self._log_violation(user_id, "harmful_output", reason)
            return "I cannot provide that information as it may be harmful."

        if self.safety_config['safety_settings']['content_filtering']['pii_detection']['enabled']:
            output = self.content_filter.redact_pii(output)

        return output

    def _log_violation(self, user_id: str, violation_type: str, details: str):
        """Log a safety violation"""
        self.violation_log.append({
            "timestamp": datetime.now().isoformat(),
            "user_id": hashlib.sha256(user_id.encode()).hexdigest()[:16],
            "type": violation_type,
            "details": details
        })

        # Cap the in-memory log at the 1000 most recent entries
        if len(self.violation_log) > 1000:
            self.violation_log = self.violation_log[-1000:]

    def generate(
        self,
        prompt: str,
        images: Optional[List[Image.Image]] = None,
        user_id: str = "default",
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Safe generation with input/output filtering

        Args:
            prompt: Input prompt
            images: Optional list of images
            user_id: User identifier for rate limiting
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional generation parameters

        Returns:
            Generated text (filtered)
        """
        if self.enable_rate_limiting:
            allowed, reason = self.rate_limiter.check_rate_limit(user_id, max_new_tokens)
            if not allowed:
                raise SafetyViolation(reason)

            self.rate_limiter.record_request(user_id, max_new_tokens)

        try:
            self._validate_input(prompt, images, user_id)

            if images:
                inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(self.model.device)
            else:
                inputs = self.processor(text=prompt, return_tensors="pt").to(self.model.device)

            # Enable sampling unless the caller overrides it; temperature has no
            # effect under the default greedy decoding.
            kwargs.setdefault("do_sample", True)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    **kwargs
                )

            response = self.processor.decode(outputs[0], skip_special_tokens=True)

            # Strip the echoed prompt, if present
            if response.startswith(prompt):
                response = response[len(prompt):].strip()

            response = self._validate_output(response, user_id)

            return response

        finally:
            if self.enable_rate_limiting:
                self.rate_limiter.release_request(user_id)

    def get_violation_stats(self) -> Dict[str, Any]:
        """Get violation statistics"""
        if not self.violation_log:
            return {"total_violations": 0}

        violation_types = defaultdict(int)
        for log in self.violation_log:
            violation_types[log['type']] += 1

        return {
            "total_violations": len(self.violation_log),
            "by_type": dict(violation_types),
            "recent_violations": self.violation_log[-10:]
        }

    def export_violation_log(self, filename: str = "violations.json"):
        """Export violation log to file"""
        with open(filename, 'w') as f:
            json.dump(self.violation_log, f, indent=2)
        print(f"Violation log exported to {filename}")


def main():
    """Example usage of the safe wrapper"""
    wrapper = SafeHelionWrapper(
        model_name="DeepXR/Helion-V2.0-Thinking",
        enable_safety=True,
        enable_rate_limiting=True
    )

    test_prompts = [
        "Explain how photosynthesis works.",
        "Write a poem about nature.",
        "How do I hack into an email account?",
    ]

    print("\n" + "=" * 60)
    print("Testing Safety Wrapper")
    print("=" * 60 + "\n")

    for prompt in test_prompts:
        print(f"Prompt: {prompt}")
        try:
            response = wrapper.generate(prompt, max_new_tokens=128)
            print(f"Response: {response}\n")
        except SafetyViolation as e:
            print(f"BLOCKED: {e}\n")
        except Exception as e:
            print(f"ERROR: {e}\n")

    print("=" * 60)
    print("Violation Statistics")
    print("=" * 60)
    stats = wrapper.get_violation_stats()
    print(json.dumps(stats, indent=2))


if __name__ == "__main__":
    main()