File size: 5,085 Bytes
3d7d9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
use adblock::{Engine, FilterSet};
use adblock::lists::ParseOptions;
use adblock::request::Request;
use std::collections::HashSet;
use std::sync::RwLock;
use std::sync::atomic::{AtomicU64, Ordering};
use serde::Serialize;

/// Thread-safe wrapper around adblock Engine
///
/// Bundles the compiled filter engine with the runtime state the shield needs.
/// Every mutable piece sits behind its own `RwLock` or atomic, so all methods
/// work through `&self`.
pub struct AdBlockState {
    /// Compiled filter engine; `reload_engine` swaps in a whole new engine.
    pub engine: RwLock<Engine>,
    /// Running counters, snapshotted by `report`.
    pub stats: AdBlockStats,
    /// Source-page domains for which network blocking is skipped. Lookups use
    /// the source URL's host as normalized by `extract_domain`, so entries must
    /// use the same normalization to match.
    pub allowlist: RwLock<HashSet<String>>,
    /// Approximate number of rules in `engine`. It is derived from line counts
    /// of the filter lists, not reported by the engine itself.
    pub rule_count: AtomicU64,
}

/// Counters describing what the shield has done since the state was created.
///
/// Fields are plain atomics, so they can be bumped through a shared reference
/// without taking a lock. `Default` starts every counter at zero.
#[derive(Debug, Default)]
pub struct AdBlockStats {
    /// Network requests for which `AdBlockState::should_block` returned `true`.
    pub blocked_requests: AtomicU64,
    /// Total hide selectors emitted by `AdBlockState::get_cosmetic_css`. The
    /// count is added per call, so repeated calls for the same page count again.
    pub blocked_cosmetic: AtomicU64,
    /// HTTPS upgrades performed.
    /// NOTE(review): not incremented by the code visible in this module;
    /// presumably updated by callers elsewhere (confirm).
    pub https_upgrades: AtomicU64,
}

/// Point-in-time, serializable snapshot of shield statistics, as produced by
/// `AdBlockState::report`.
#[derive(Debug, Clone, Serialize)]
pub struct ShieldReport {
    /// Value of `AdBlockStats::blocked_requests` when the snapshot was taken.
    pub blocked_requests: u64,
    /// Value of `AdBlockStats::blocked_cosmetic` when the snapshot was taken.
    pub blocked_cosmetic: u64,
    /// Value of `AdBlockStats::https_upgrades` when the snapshot was taken.
    pub https_upgrades: u64,
    /// Approximate rule count of the active engine (from `AdBlockState::rule_count`).
    pub engine_rules: usize,
    /// Number of allowlist entries; `report` uses 0 if the allowlist lock is poisoned.
    pub allowlisted_domains: usize,
}

impl AdBlockState {
    /// Builds a state whose engine is compiled from the bundled filter lists,
    /// with zeroed statistics and an empty allowlist.
    pub fn new() -> Self {
        let (engine, count) = build_engine_from_bundled();
        Self {
            engine: RwLock::new(engine),
            stats: AdBlockStats::default(),
            allowlist: RwLock::new(HashSet::new()),
            rule_count: AtomicU64::new(count as u64),
        }
    }

    /// Takes a shared read guard on the engine, recovering it if the lock is poisoned.
    ///
    /// Without recovery, one panicking writer would make every later lookup
    /// fail open (nothing blocked) for the rest of the process, with no signal.
    /// This impl only ever assigns a whole `Engine` under the write lock, so the
    /// value behind a poisoned lock is still a usable (possibly stale) engine.
    /// NOTE(review): `engine` is `pub`; if code elsewhere mutates the engine in
    /// place under a write guard, confirm this recovery is still appropriate.
    fn read_engine(&self) -> std::sync::RwLockReadGuard<'_, Engine> {
        self.engine
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Decides whether the request for `url`, issued from the page at
    /// `source_url` with the given `request_type`, should be blocked.
    ///
    /// Returns `false` without consulting the engine when the source page's
    /// domain is in the allowlist, and `false` when adblock cannot build a
    /// `Request` from the inputs. Every `true` result increments
    /// `stats.blocked_requests`.
    pub fn should_block(&self, url: &str, source_url: &str, request_type: &str) -> bool {
        if let Some(domain) = extract_domain(source_url) {
            // A poisoned allowlist is skipped rather than trusted, so the
            // request falls through to normal filtering.
            if let Ok(allowlist) = self.allowlist.read() {
                if allowlist.contains(&domain) {
                    return false;
                }
            }
        }

        // Parse before taking the engine lock so the guard covers only the lookup.
        let request = match Request::new(url, source_url, request_type) {
            Ok(request) => request,
            Err(_) => return false,
        };

        let result = self.read_engine().check_network_request(&request);
        // Blocked only when a rule matched and no exception was reported.
        let blocked = result.matched && result.exception.is_none();
        if blocked {
            self.stats.blocked_requests.fetch_add(1, Ordering::Relaxed);
        }
        blocked
    }

    /// Returns a stylesheet hiding every selector in the `hide_selectors`
    /// adblock reports for `url`. The format is one
    /// `selector{display:none!important}` rule per line.
    ///
    /// Returns an empty string when there are no selectors. Otherwise the
    /// number of selectors is added to `stats.blocked_cosmetic` on every call.
    pub fn get_cosmetic_css(&self, url: &str) -> String {
        // The guard is a temporary, so the lock is released once the owned
        // resources have been fetched.
        let resources = self.read_engine().url_cosmetic_resources(url);
        let selectors = &resources.hide_selectors;
        if selectors.is_empty() {
            return String::new();
        }

        // One rule per selector: in CSS an invalid selector invalidates its
        // whole rule, so separate rules keep one bad selector from voiding the rest.
        let css = selectors
            .iter()
            .map(|s| format!("{s}{{display:none!important}}"))
            .collect::<Vec<_>>()
            .join("\n");

        self.stats
            .blocked_cosmetic
            .fetch_add(selectors.len() as u64, Ordering::Relaxed);
        css
    }

    /// Get injected scriptlets for a URL (json-prune, set-constant, etc.)
    /// Used for YouTube/Twitch/video ad blocking
    ///
    /// Returns the `injected_script` adblock reports for `url` as-is.
    pub fn get_injected_script(&self, url: &str) -> String {
        self.read_engine().url_cosmetic_resources(url).injected_script
    }

    /// Snapshots the current counters, rule count and allowlist size.
    ///
    /// `allowlisted_domains` is 0 if the allowlist lock is poisoned.
    pub fn report(&self) -> ShieldReport {
        let engine_rules = self.rule_count.load(Ordering::Relaxed) as usize;
        let allowlisted = self.allowlist.read().map(|a| a.len()).unwrap_or(0);
        ShieldReport {
            blocked_requests: self.stats.blocked_requests.load(Ordering::Relaxed),
            blocked_cosmetic: self.stats.blocked_cosmetic.load(Ordering::Relaxed),
            https_upgrades: self.stats.https_upgrades.load(Ordering::Relaxed),
            engine_rules,
            allowlisted_domains: allowlisted,
        }
    }

    /// Installs `new_engine` and records `rule_count` as its rule count.
    ///
    /// A poisoned lock is recovered, so the engine is always replaced. The
    /// stored count therefore never describes an engine that was not installed.
    pub fn reload_engine(&self, new_engine: Engine, rule_count: usize) {
        let old_engine = {
            let mut engine = self
                .engine
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            std::mem::replace(&mut *engine, new_engine)
        };
        self.rule_count.store(rule_count as u64, Ordering::Relaxed);
        // Tear down the previous engine only after the write guard is released,
        // so readers are not blocked while it is dropped.
        drop(old_engine);
    }
}

impl Default for AdBlockState {
    fn default() -> Self {
        Self::new()
    }
}

/// Compiles the filter lists bundled into the binary into an `Engine`.
///
/// Returns the engine together with an approximate rule count. The count is
/// the number of lines that look like rules (see `count_rule_lines`), not a
/// figure reported by adblock, so lines the parser rejects are still included.
pub fn build_engine_from_bundled() -> (Engine, usize) {
    // `false` is presumably adblock's `debug` flag (see `FilterSet::new`).
    let mut filter_set = FilterSet::new(false);

    let lists: &[&str] = &[
        include_str!("../../resources/filters/easylist_mini.txt"),
        include_str!("../../resources/filters/easyprivacy_mini.txt"),
        include_str!("../../resources/filters/annoyances_mini.txt"),
    ];

    let mut count = 0usize;
    for list in lists {
        // The returned list metadata is not needed here.
        let _meta = filter_set.add_filter_list(list, ParseOptions::default());
        count += count_rule_lines(list);
    }

    // `true` is presumably adblock's `optimize` flag (see `Engine::from_filter_set`).
    (Engine::from_filter_set(filter_set, true), count)
}

/// Counts lines of an ABP-style filter list that look like rules.
///
/// A line counts when, after trimming, it is non-empty, is not a `!` comment
/// and is not an `[Adblock ...]` header.
fn count_rule_lines(list: &str) -> usize {
    list.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .filter(|line| !line.starts_with('!') && !line.starts_with("[Adblock"))
        .count()
}

/// Extracts the normalized host of `url`, used as the allowlist lookup key.
///
/// The authority is the text after the first `//`, ending at the first `/`,
/// `?` or `#`. Any `user:pass@` credentials and `:port` suffix are removed;
/// bracketed IPv6 literals such as `[::1]:8080` yield `::1`. The host is then
/// lowercased and one leading `www.` is stripped.
///
/// Returns `None` when `url` contains no `//` (e.g. `about:blank`) or when the
/// resulting host is empty (e.g. `file:///etc/hosts`).
fn extract_domain(url: &str) -> Option<String> {
    let (_, after_scheme) = url.split_once("//")?;
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(after_scheme);

    // Credentials precede the host; the last `@` in the authority ends them.
    let host_and_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);

    let host = match host_and_port.strip_prefix('[') {
        // Bracketed IPv6 literal: its internal colons are not a port separator.
        Some(ipv6) => ipv6.split(']').next().unwrap_or(ipv6),
        None => host_and_port.split(':').next().unwrap_or(host_and_port),
    };

    let lowered = host.to_lowercase();
    let bare = lowered.strip_prefix("www.").unwrap_or(lowered.as_str());
    (!bare.is_empty()).then(|| bare.to_string())
}