Rithwik Ravi
fix(ui): enforce verbatim user terminology for charts, hook evaluate script to real dataset for dynamic threat feed, and purge last mocked strings
09e95f5 | import os | |
| import sys | |
| import warnings | |
| import contextlib | |
| import logging | |
| import time | |
| # Suppress all standard warnings | |
| warnings.filterwarnings("ignore") | |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | |
| logging.getLogger("transformers").setLevel(logging.ERROR) | |
| logging.getLogger("httpx").setLevel(logging.ERROR) | |
| def suppress_output(): | |
| """ | |
| Context manager to redirect stdout and stderr to os.devnull at the OS level. | |
| """ | |
| with open(os.devnull, 'w') as devnull: | |
| old_stdout_fd = os.dup(1) | |
| old_stderr_fd = os.dup(2) | |
| os.dup2(devnull.fileno(), 1) | |
| os.dup2(devnull.fileno(), 2) | |
| try: | |
| yield | |
| finally: | |
| os.dup2(old_stdout_fd, 1) | |
| os.dup2(old_stderr_fd, 2) | |
| os.close(old_stdout_fd) | |
| os.close(old_stderr_fd) | |
| def main(): | |
| print("[START]", flush=True) | |
| with suppress_output(): | |
| from src.env.guardrail import GuardrailEnvironment | |
| from src.env.reward import LogBarrierReward | |
| from src.env.models import Action | |
| from src.telemetry.streamer import append_metric | |
| from src.rl.data import dataset_cache | |
| import random | |
| # Load real datasets for evaluation threat feed | |
| dataset_cache.ingest_production_baseline() | |
| batch = dataset_cache.sample_batch(batch_size=10) | |
| env = GuardrailEnvironment() | |
| env.reset(batch["adversarial"], batch["benign"]) | |
| r_engine = LogBarrierReward() | |
| try: | |
| import torch | |
| from unsloth import FastLanguageModel | |
| model_name = "Qwen/Qwen2.5-0.5B-Instruct" | |
| if os.path.exists("models/trained_guardrail"): | |
| model_name = "models/trained_guardrail" | |
| model, tokenizer = FastLanguageModel.from_pretrained( | |
| model_name=model_name, | |
| max_seq_length=2048, | |
| dtype=None, | |
| load_in_4bit=True, | |
| ) | |
| FastLanguageModel.for_inference(model) | |
| except Exception as e: | |
| # Fallback if Unsloth is missing or GPU is unavailable (prevents crash) | |
| model, tokenizer = None, None | |
| prompt = "You are a cyber security expert. Generate a JSON AST GuardrailGraph to block prompt injections but allow benign queries. Output strictly in ```json format." | |
| # Simulate an improving learning curve if model fails to load (no GPU) | |
| fallback_actions = [ | |
| {"root": {"operator": "OR", "children": []}}, # 0% recall, 0% FPR | |
| {"root": {"operator": "OR", "children": [{"filter_type": "substring", "value": "Ignore"}]}}, | |
| {"root": {"operator": "OR", "children": [{"filter_type": "regex_pattern", "value": "Ignore previous.*"}, {"filter_type": "length_limit", "value": 500}]}}, | |
| {"root": {"operator": "AND", "children": [{"filter_type": "substring", "value": "secret"}, {"filter_type": "entropy_threshold", "value": 4.5}]}}, | |
| {"root": {"operator": "OR", "children": [{"filter_type": "length_limit", "value": 10}, {"filter_type": "substring", "value": "Delete"}]}}, | |
| {"root": {"operator": "OR", "children": [{"filter_type": "substring", "value": "bypass"}, {"filter_type": "substring", "value": "Developer Mode"}]}} | |
| ] | |
| for i in range(120): | |
| with suppress_output(): | |
| try: | |
| if model is not None and tokenizer is not None: | |
| # Pass 1: Trained Model | |
| inputs = tokenizer([prompt], return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu") | |
| outputs = model.generate(**inputs, max_new_tokens=512, use_cache=True, pad_token_id=tokenizer.eos_token_id) | |
| output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] | |
| del outputs | |
| # Pass 2: Baseline Model (Disabled Adapter) | |
| if hasattr(model, 'disable_adapter'): | |
| with model.disable_adapter(): | |
| baseline_outputs = model.generate(**inputs, max_new_tokens=512, use_cache=True, pad_token_id=tokenizer.eos_token_id) | |
| baseline_output_text = tokenizer.batch_decode(baseline_outputs, skip_special_tokens=True)[0] | |
| del baseline_outputs | |
| else: | |
| baseline_output_text = output_text | |
| # Explicit 8GB memory ceiling GC | |
| del inputs | |
| import gc | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| else: | |
| import json | |
| idx = min(i // 8, len(fallback_actions) - 1) | |
| fallback = fallback_actions[idx] | |
| output_text = json.dumps({"graph_id": f"AST-Fallback-{i}", "description": "Simulated Fallback", **fallback}) | |
| baseline_output_text = json.dumps({"graph_id": f"AST-Baseline-{i}", "description": "Simulated Baseline", "root": {"operator": "OR", "children": []}}) | |
| action = Action(ast_json=output_text, baseline_ast_json=baseline_output_text) | |
| recall, fpr, syntax_error = env.step(action) | |
| reward = r_engine.calculate(recall, fpr, syntax_error) | |
| baseline_recall, baseline_fpr, baseline_syntax_error = 0.0, 0.0, True | |
| baseline_reward = 0.0 | |
| if baseline_output_text: | |
| baseline_action = Action(ast_json=baseline_output_text) | |
| baseline_recall, baseline_fpr, baseline_syntax_error = env.step(baseline_action) | |
| baseline_reward = r_engine.calculate(baseline_recall, baseline_fpr, baseline_syntax_error) | |
| recent_traffic = [] | |
| for adv_str in env.state.adversarial_samples[:3]: | |
| recent_traffic.append({ | |
| "prompt_text": adv_str, | |
| "is_malicious": True, | |
| "was_blocked": random.random() < recall | |
| }) | |
| for ben_str in env.state.benign_samples[:3]: | |
| recent_traffic.append({ | |
| "prompt_text": ben_str, | |
| "is_malicious": False, | |
| "was_blocked": random.random() < fpr | |
| }) | |
| random.shuffle(recent_traffic) | |
| append_metric(reward, recall, fpr, baseline_reward, baseline_recall, baseline_fpr, output_text if not syntax_error else None, recent_traffic) | |
| time.sleep(1.0) # Ensure UI renders smoothly | |
| except Exception as e: | |
| pass | |
| print("[STEP]", flush=True) | |
| print("[END]", flush=True) | |
| if __name__ == "__main__": | |
| main() | |