e1732a364fed committed on
Commit
3a60eea
·
1 Parent(s): 26b720e

basic feature

Browse files
Files changed (2) hide show
  1. .gitignore +3 -0
  2. app.py +116 -3
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/
2
+ .venv/
3
+ flagged/
app.py CHANGED
@@ -1,9 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!"
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
9
  demo.launch()
 
1
def extract_body_text_by_string(input_string, max_len=512):
    """Condense *input_string* down to roughly *max_len* characters.

    A string no longer than *max_len* is simply returned stripped.
    Anything longer is sampled at three offsets — the head, the middle,
    and the tail — each sample being one third of *max_len* long; the
    stripped samples are concatenated and returned.
    """
    total = len(input_string)

    if total <= max_len:
        return input_string.strip()

    sample_len = max_len // 3
    # Starting offsets for the head, middle, and tail samples.
    starts = (0, total // 2, total - sample_len)

    return "".join(
        input_string[start : start + sample_len].strip() for start in starts
    )
17
+
18
+
19
+ import torch
20
+
21
# Pick the best available accelerator: CUDA first, then Apple MPS,
# falling back to plain CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"

DEVICE = torch.device(_device_name)
26
+
27
+ import numpy as np
28
+
29
+
30
def predict_text(text, model, tokenizer, max_len, device=DEVICE):
    """Classify *text* with a sequence-classification *model*.

    The text is tokenized (padded/truncated to *max_len*), run through
    the model without gradient tracking, and the argmax class is mapped
    to a verdict string: class 0 -> "ok", anything else -> "ban".
    """
    encoded = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        result = model(ids, attention_mask=mask)
        scores = result.logits.cpu().numpy()
        predicted = np.argmax(scores, axis=1)[0]

    return "ok" if predicted == 0 else "ban"
51
+
52
+
53
  import gradio as gr
54
 
55
print("loading models...")
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
)

# Hub repo ids for the two fine-tuned BERT classifiers.
# (The originals used f-strings with no placeholders — ruff F541 — and
# repeated each literal twice; named constants remove the duplication.)
HEAD_MODEL_ID = "e1732a364fed/bert-geosite-classification-head-v1"
BODY_MODEL_ID = "e1732a364fed/bert-geosite-classification-body-v1"

head_tokenizer = BertTokenizer.from_pretrained(HEAD_MODEL_ID)
body_tokenizer = BertTokenizer.from_pretrained(BODY_MODEL_ID)
head_model = BertForSequenceClassification.from_pretrained(HEAD_MODEL_ID).to(DEVICE)
body_model = BertForSequenceClassification.from_pretrained(BODY_MODEL_ID).to(DEVICE)

# Inference only: switch off dropout / batch-norm training behavior.
head_model.eval()
body_model.eval()

print("loaded models...")
78
+
79
+
80
def func(head, body):
    """Run both classifiers: the head model on *head*, the body model on *body*.

    Returns a pair of verdict strings ("ok" or "ban"), one per input.
    """
    print("predicting head")

    head_verdict = predict_text(head, head_model, head_tokenizer, 512)
    print("predicting body")

    body_verdict = predict_text(body, body_model, body_tokenizer, 512)
    print("prediction done")

    return head_verdict, body_verdict
90
 
 
 
91
 
92
# Assemble the UI: two free-text inputs (HTTP response head and body)
# mapped through `func` to two verdict textboxes.
head_input = gr.Textbox(
    label="Head",
    info="http response head",
    lines=6,
    max_lines=10000000000,
    value="200 OK",
)
body_input = gr.Textbox(
    label="Body",
    info="http response body",
    lines=6,
    max_lines=10000000000,
    value="<body>The quick brown fox jumped over the lazy dogs.</body>",
)
head_output = gr.Textbox(
    label="Head",
    info="http response prediction",
)
body_output = gr.Textbox(
    label="Body",
    info="http response prediction",
)

demo = gr.Interface(
    fn=func,
    inputs=[head_input, body_input],
    outputs=[head_output, body_output],
)

demo.launch()