# NOTE(review): removed non-Python residue that preceded this script — a web
# viewer's gutter (file-size banner, commit hashes, and a 1-193 line-number
# column) that was captured along with the source and made the file unparsable.
def extract_body_text_by_string(input_string, max_len=512):
    """Sample up to roughly ``max_len`` characters from a long string.

    Short inputs (``len <= max_len``) are returned stripped and unchanged.
    Longer inputs are sampled as three chunks — head, middle, and tail —
    each ``max_len // 3`` characters, concatenated after stripping.

    Fix vs. the original: when the input is only slightly longer than
    ``max_len``, the head/middle/tail windows used to overlap, so the
    result contained duplicated text. Each window now starts no earlier
    than the end of the previous one.

    Args:
        input_string: Text to sample from (e.g. an HTTP response body).
        max_len: Approximate upper bound on the returned length.

    Returns:
        The stripped input, or the concatenation of the three sampled
        chunks (each individually stripped).
    """
    string_length = len(input_string)

    if string_length <= max_len:
        return input_string.strip()

    chunk_size = max_len // 3  # split the budget into three equal chunks
    # Window starts: head, middle, tail.
    positions = [0, string_length // 2, string_length - chunk_size]

    extracted_text = []
    prev_end = 0
    for pos in positions:
        start = max(pos, prev_end)  # never re-read text already sampled
        text_chunk = input_string[start : start + chunk_size]
        extracted_text.append(text_chunk.strip())
        prev_end = start + chunk_size

    return "".join(extracted_text)


import torch

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print("device", DEVICE)

import numpy as np


def f_predict_text(text, model, tokenizer, max_len, device=DEVICE):
    """Classify a single text with a sequence-classification model.

    The text is tokenized (padded/truncated to ``max_len``), run through
    ``model`` on ``device`` without gradients, and the argmax logit is
    mapped to a verdict: class 0 -> "ok", any other class -> "ban".
    """
    batch = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    ids = batch["input_ids"].to(device)
    mask = batch["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits.cpu().numpy()
        predicted_class = np.argmax(logits, axis=1)[0]

    return "ok" if predicted_class == 0 else "ban"


import gradio as gr

# Lazily-populated cache of (head_tokenizer, body_tokenizer, head_model,
# body_model); filled by predict_by_text() on first use so the app starts
# without downloading the models up front.
_models = None


def load_models():
    """Download and prepare the head/body BERT classifiers.

    Returns:
        tuple: ``(head_tokenizer, body_tokenizer, head_model, body_model)``
        with both models moved to ``DEVICE`` and switched to eval mode.
    """
    print("loading models...")

    # Imported here so merely importing this module doesn't pull in
    # transformers; the models are only needed on first prediction.
    from transformers import (
        BertTokenizer,
        BertForSequenceClassification,
    )

    # Hub repo ids, each used for both the tokenizer and the model.
    # (The originals duplicated these strings with pointless f-prefixes.)
    head_repo = "e1732a364fed/bert-geosite-classification-head-v1"
    body_repo = "e1732a364fed/bert-geosite-classification-body-v1"

    head_tokenizer = BertTokenizer.from_pretrained(head_repo)
    body_tokenizer = BertTokenizer.from_pretrained(body_repo)
    head_model = BertForSequenceClassification.from_pretrained(head_repo).to(DEVICE)
    body_model = BertForSequenceClassification.from_pretrained(body_repo).to(DEVICE)

    # Inference-only: disable dropout etc.
    head_model.eval()
    body_model.eval()
    print("loaded models...")

    return head_tokenizer, body_tokenizer, head_model, body_model


def predict_by_text(head, body):
    """Classify an HTTP response's head and body strings.

    Loads the models lazily into the module-level ``_models`` cache on the
    first call. Returns a pair of verdicts ("ok"/"ban") — one for the head,
    one for the (down-sampled) body.
    """
    print("predicting head")

    global _models
    if _models is None:
        print("loading models...")
        _models = load_models()
    h_tok, b_tok, h_model, b_model = _models

    head_verdict = f_predict_text(head, h_model, h_tok, 512)
    print("predicting body")

    body_sample = extract_body_text_by_string(body)
    body_verdict = f_predict_text(body_sample, b_model, b_tok, 512)
    print("prediction done")

    return head_verdict, body_verdict


# Gradio UI: one button classifies pasted head/body text, the other fetches
# a URL and classifies the live response.
with gr.Blocks() as demo:
    btn1 = gr.Button("Classify by Text")

    @btn1.click(
        inputs=[
            gr.Textbox(
                label="Head",
                info="http response head",
                lines=6,
                max_lines=10000,
                value="200 OK",
            ),
            gr.Textbox(
                label="Body",
                info="http response body",
                lines=6,
                max_lines=10000000000,
                value="<body>The quick brown fox jumped over the lazy dogs.</body>",
            ),
        ],
        outputs=[
            gr.Textbox(
                label="Head",
                info="http response prediction",
            ),
            gr.Textbox(
                label="Body",
                info="http response prediction",
            ),
        ],
    )
    def f(h, b):
        """Classify pasted head/body text."""
        return predict_by_text(h, b)

    btn2 = gr.Button("Classify by Website URL")

    @btn2.click(
        inputs=[
            gr.Textbox(
                label="Website URL",
                # Fixed: this info text was copy-pasted from the head textbox.
                info="url to fetch and classify",
                value="https://httpbin.org/get",
            )
        ],
        outputs=[
            gr.Textbox(
                label="Head Classify",
            ),
            gr.Textbox(
                label="Body Classify",
            ),
            gr.Textbox(
                label="Head",
                lines=6,
            ),
            gr.Textbox(
                label="Body",
                lines=6,
            ),
        ],
    )
    def f2(url):
        """Fetch *url* and classify its response head and body."""
        import requests

        # Browser-like headers so sites serve their normal HTML.
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:139.0) Gecko/20100101 Firefox/139.0",
            # Fixed: was the typo "tetext/html,...".
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Cache-Control": "max-age=0",
            "Connection": "keep-alive",
        }

        # requests.get never returns a str — it raises on failure — so the
        # original `isinstance(response, str)` guard was dead code and any
        # network error crashed the handler. Catch it properly instead.
        try:
            response = requests.get(url, headers=headers, timeout=4)
        except requests.RequestException as e:
            msg = f"Request failed: {e}"
            return msg, msg, msg, msg

        h = "\n".join(f"{key}: {value}" for key, value in response.headers.items())
        b = extract_body_text_by_string(response.text)
        print(response)

        hr, br = predict_by_text(h, b)
        return hr, br, h, b


demo.launch()