File size: 2,596 Bytes
cffa613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80b34d1
cffa613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80b34d1
cffa613
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import random
from typing import List, Dict
import logging

try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None

class BoundedMemoryCache:
    def __init__(self, max_size: int = 500):
        self.max_size = max_size
        self.adversarial_cache: List[str] = []
        self.benign_cache: List[str] = []

    def ingest_production_baseline(self):
        logging.info("Downloading bounded HF datasets...")
        if load_dataset:
            try:
                # Jackhhao jailbreak-classification for Adversarial
                adv_ds = load_dataset("jackhhao/jailbreak-classification", split="train", streaming=True)
                adv_iter = iter(adv_ds)
                for _ in range(self.max_size):
                    try:
                        sample = next(adv_iter)
                        self.adversarial_cache.append(sample.get("text", sample.get("prompt", "")))
                    except StopIteration:
                        break

                # XSTest for Benign/False Positives
                # xspelled_out/XSTest might need split "test" or "train", assuming "test" given usual XSTest shape, but let's try "test" then "train"
                try:
                    ben_ds = load_dataset("xspelled_out/XSTest", split="test", streaming=True)
                except:
                    ben_ds = load_dataset("xspelled_out/XSTest", split="train", streaming=True)
                ben_iter = iter(ben_ds)
                for _ in range(self.max_size):
                    try:
                        sample = next(ben_iter)
                        self.benign_cache.append(sample.get("text", sample.get("prompt", "")))
                    except StopIteration:
                        break
            except Exception as e:
                logging.warning(f"Failed loading HF dataset: {e}. Falling back to mocks.")
                
        if not self.adversarial_cache:
            self.adversarial_cache = ["Mock adversarial payload"] * self.max_size
        if not self.benign_cache:
            self.benign_cache = ["Mock benign payload"] * self.max_size

    def sample_batch(self, batch_size: int = 16) -> Dict[str, List[str]]:
        if not self.adversarial_cache:
            self.ingest_production_baseline()
            
        adv = random.sample(self.adversarial_cache, min(batch_size, len(self.adversarial_cache)))
        ben = random.sample(self.benign_cache, min(batch_size, len(self.benign_cache)))
        
        return {"adversarial": adv, "benign": ben}

dataset_cache = BoundedMemoryCache(max_size=500)