# (Hugging Face Spaces status banner residue: "Spaces: Sleeping")
def extract_body_text_by_string(input_string, max_len=512):
    """Condense a long string by sampling its head, middle and tail.

    Strings no longer than ``max_len`` are returned stripped and otherwise
    unchanged. Longer strings are reduced to three chunks of
    ``max_len // 3`` characters taken from the start, the midpoint and the
    end of the string; each chunk is stripped and they are concatenated
    with no separator.
    """
    total = len(input_string)
    if total <= max_len:
        return input_string.strip()
    chunk = max_len // 3
    # Sample offsets: head, middle, tail of the string.
    starts = (0, total // 2, total - chunk)
    return "".join(input_string[s : s + chunk].strip() for s in starts)
import torch

# Pick the best available accelerator: CUDA first, then Apple Metal (MPS),
# otherwise fall back to CPU. All models below are moved to this device.
DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print("device", DEVICE)
| import numpy as np | |
def f_predict_text(text, model, tokenizer, max_len, device=DEVICE):
    """Classify *text* with a sequence-classification model.

    The text is tokenized to exactly ``max_len`` tokens (padded and
    truncated), run through *model* with gradients disabled, and the
    argmax over the logits is mapped to a label: ``"ok"`` for class 0,
    ``"ban"`` for anything else.
    """
    enc = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    ids = enc["input_ids"].to(device)
    mask = enc["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits

    predicted_class = np.argmax(logits.cpu().numpy(), axis=1)[0]
    return "ok" if predicted_class == 0 else "ban"
| import gradio as gr | |
# Lazily-loaded (head_tokenizer, body_tokenizer, head_model, body_model)
# tuple; populated on first prediction by predict_by_text().
_models = None


def load_models():
    """Download and prepare the header/body classifier models.

    Returns a tuple ``(head_tokenizer, body_tokenizer, head_model,
    body_model)`` with both models moved to DEVICE and switched to eval
    mode (inference only — disables dropout).
    """
    print("loading models...")
    # Imported lazily so the heavy transformers import only happens when a
    # prediction is actually requested.
    from transformers import (
        BertTokenizer,
        BertForSequenceClassification,
    )

    # Single source of truth for the Hub repo IDs (previously each ID was
    # repeated as a placeholder-free f-string).
    head_repo = "e1732a364fed/bert-geosite-classification-head-v1"
    body_repo = "e1732a364fed/bert-geosite-classification-body-v1"

    head_tokenizer = BertTokenizer.from_pretrained(head_repo)
    body_tokenizer = BertTokenizer.from_pretrained(body_repo)
    head_model = BertForSequenceClassification.from_pretrained(head_repo).to(DEVICE)
    body_model = BertForSequenceClassification.from_pretrained(body_repo).to(DEVICE)

    head_model.eval()
    body_model.eval()
    print("loaded models...")
    return head_tokenizer, body_tokenizer, head_model, body_model
def predict_by_text(head, body):
    """Classify HTTP header text and (condensed) body text.

    Loads the models on first use and caches them in the module-level
    ``_models`` tuple. Returns a ``(head_label, body_label)`` pair where
    each label is ``"ok"`` or ``"ban"``.
    """
    global _models
    print("predicting head")
    if _models is None:
        print("loading models...")
        _models = load_models()
    h_tok, b_tok, h_model, b_model = _models

    head_label = f_predict_text(head, h_model, h_tok, 512)
    print("predicting body")
    body_label = f_predict_text(
        extract_body_text_by_string(body), b_model, b_tok, 512
    )
    print("prediction done")
    return head_label, body_label
with gr.Blocks() as demo:
    # --- Classify raw header/body text pasted by the user -------------
    head_in = gr.Textbox(label="HTTP response headers", lines=6)
    body_in = gr.Textbox(label="Page body text", lines=6)
    head_out = gr.Textbox(label="Header verdict")
    body_out = gr.Textbox(label="Body verdict")
    btn1 = gr.Button("Classify by Text")

    def f(h, b):
        """Classify user-supplied header and body text."""
        return predict_by_text(h, b)

    # BUG FIX: both buttons were created but never wired to a callback and
    # had no input/output components, so clicking them did nothing.
    btn1.click(f, inputs=[head_in, body_in], outputs=[head_out, body_out])

    # --- Classify by fetching a URL -----------------------------------
    url_in = gr.Textbox(label="Website URL")
    btn2 = gr.Button("Classify by Website URL")
    url_head_out = gr.Textbox(label="Header verdict")
    url_body_out = gr.Textbox(label="Body verdict")
    fetched_head = gr.Textbox(label="Fetched headers", lines=6)
    fetched_body = gr.Textbox(label="Fetched body (condensed)", lines=6)

    def f2(url):
        """Fetch *url* and classify its response headers and body.

        Returns ``(head_label, body_label, headers_text, body_text)``; on
        a network error, the error message fills all four slots.
        """
        import requests

        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:139.0) Gecko/20100101 Firefox/139.0",
            # BUG FIX: was the invalid media type "tetext/html".
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Cache-Control": "max-age=0",
            "Connection": "keep-alive",
        }
        # BUG FIX: requests.get() raises on failure instead of returning a
        # string, so the old `isinstance(response, str)` check was
        # unreachable and any network error crashed the handler.
        try:
            response = requests.get(url, headers=headers, timeout=4)
        except requests.RequestException as exc:
            err = f"Request failed: {exc}"
            return err, err, err, err

        h = "\n".join(f"{key}: {value}" for key, value in response.headers.items())
        b = extract_body_text_by_string(response.text)
        print(response)
        hr, br = predict_by_text(h, b)
        return hr, br, h, b

    btn2.click(
        f2,
        inputs=[url_in],
        outputs=[url_head_out, url_body_out, fetched_head, fetched_body],
    )

demo.launch()