# test / app.py
# e1732a364fed -- Classify by Website URL (commit fee53f8)
def extract_body_text_by_string(input_string, max_len=512):
    """Shorten text to at most ``max_len`` characters by sampling it.

    Text no longer than ``max_len`` is returned stripped as-is. Longer text
    is sampled at three spots -- the start, the midpoint, and the tail --
    each sample being ``max_len // 3`` characters; the stripped samples are
    concatenated without a separator.
    """
    total = len(input_string)
    if total <= max_len:
        return input_string.strip()
    sample_len = max_len // 3  # one third of the budget per sample
    starts = (0, total // 2, total - sample_len)  # head, middle, tail
    return "".join(input_string[s : s + sample_len].strip() for s in starts)
import torch

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print("device", DEVICE)

import numpy as np


def f_predict_text(text, model, tokenizer, max_len, device=DEVICE):
    """Run a binary sequence classifier on ``text``.

    The text is tokenized to a fixed ``max_len`` (padded/truncated), moved to
    ``device``, and scored without gradients. Returns ``"ok"`` when the
    argmax class is 0, otherwise ``"ban"``.
    """
    enc = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    ids = enc["input_ids"].to(device)
    mask = enc["attention_mask"].to(device)
    with torch.no_grad():
        out = model(ids, attention_mask=mask)
    scores = out.logits.cpu().numpy()
    label = int(np.argmax(scores, axis=1)[0])
    return "ok" if label == 0 else "ban"
import gradio as gr
# Lazily populated cache: (head_tokenizer, body_tokenizer, head_model, body_model)
_models = None


def load_models():
    """Instantiate both BERT classifiers and their tokenizers.

    Downloads (or loads from cache) the head- and body-classification
    checkpoints, moves both models to ``DEVICE``, switches them to eval
    mode, and returns (head_tokenizer, body_tokenizer, head_model,
    body_model).
    """
    print("loading models...")
    from transformers import (
        BertTokenizer,
        BertForSequenceClassification,
    )

    head_repo = "e1732a364fed/bert-geosite-classification-head-v1"
    body_repo = "e1732a364fed/bert-geosite-classification-body-v1"
    head_tokenizer = BertTokenizer.from_pretrained(head_repo)
    body_tokenizer = BertTokenizer.from_pretrained(body_repo)
    head_model = BertForSequenceClassification.from_pretrained(head_repo).to(DEVICE)
    body_model = BertForSequenceClassification.from_pretrained(body_repo).to(DEVICE)
    head_model.eval()
    body_model.eval()
    print("loaded models...")
    return head_tokenizer, body_tokenizer, head_model, body_model
def predict_by_text(head, body):
    """Classify an HTTP response's head and body strings.

    Loads the models on first use (cached in the module-global ``_models``)
    and returns a pair of labels ("ok"/"ban") for head and body. The body is
    down-sampled via ``extract_body_text_by_string`` before scoring.
    """
    global _models
    print("predicting head")
    if _models is None:
        print("loading models...")
        _models = load_models()
    head_tok, body_tok, head_mdl, body_mdl = _models
    head_label = f_predict_text(head, head_mdl, head_tok, 512)
    print("predicting body")
    body_label = f_predict_text(
        extract_body_text_by_string(body), body_mdl, body_tok, 512
    )
    print("prediction done")
    return head_label, body_label
# Gradio UI: one button classifies pasted head/body text, the other fetches
# a URL and classifies the live response.
with gr.Blocks() as demo:
    btn1 = gr.Button("Classify by Text")

    @btn1.click(
        inputs=[
            gr.Textbox(
                label="Head",
                info="http response head",
                lines=6,
                max_lines=10000,
                value="200 OK",
            ),
            gr.Textbox(
                label="Body",
                info="http response body",
                lines=6,
                max_lines=10000000000,
                value="<body>The quick brown fox jumped over the lazy dogs.</body>",
            ),
        ],
        outputs=[
            gr.Textbox(
                label="Head",
                info="http response prediction",
            ),
            gr.Textbox(
                label="Body",
                info="http response prediction",
            ),
        ],
    )
    def f(h, b):
        # Thin named wrapper so the click handler is a stable callable.
        return predict_by_text(h, b)

    btn2 = gr.Button("Classify by Website URL")

    @btn2.click(
        inputs=[
            gr.Textbox(
                label="Website URL",
                info="http response head",
                value="https://httpbin.org/get",
            )
        ],
        outputs=[
            gr.Textbox(
                label="Head Classify",
            ),
            gr.Textbox(
                label="Body Classify",
            ),
            gr.Textbox(
                label="Head",
                lines=6,
            ),
            gr.Textbox(
                label="Body",
                lines=6,
            ),
        ],
    )
    def f2(url):
        """Fetch ``url`` and classify its response head and body.

        Returns (head_label, body_label, head_text, body_excerpt); on a
        failed request all four outputs carry the same error message.
        """
        import requests

        headers = {
            # Present a regular browser fingerprint so sites serve normal content.
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:139.0) Gecko/20100101 Firefox/139.0",
            # FIX: was "tetext/html" -- a typo advertising a bogus MIME type.
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Cache-Control": "max-age=0",
            "Connection": "keep-alive",
        }
        # FIX: requests.get never returns a str -- failures raise, so the old
        # `isinstance(response, str)` branch was unreachable and errors
        # crashed the handler. Catch the requests exception hierarchy and
        # surface the message in every output box instead.
        try:
            response = requests.get(url, headers=headers, timeout=4)
        except requests.RequestException as e:
            r = f"Request failed: {e}"
            return r, r, r, r
        h_text = [f"{key}: {value}" for key, value in response.headers.items()]
        h = "\n".join(h_text)
        b = extract_body_text_by_string(response.text)
        print(response)
        hr, br = predict_by_text(h, b)
        return hr, br, h, b

demo.launch()