darkQibit commited on
Commit
b5cd69e
·
verified ·
1 Parent(s): 951b174

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from peft import PeftModel
6
+
7
+ MODEL_ID = "ruSpamModels/ruSpam-Qwen-0.5B-50k"
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ base_model = AutoModelForCausalLM.from_pretrained(
12
+ "Qwen/Qwen2.5-0.5B-Instruct",
13
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
14
+ device_map=device,
15
+ trust_remote_code=True,
16
+ )
17
+
18
+ model = PeftModel.from_pretrained(base_model, MODEL_ID)
19
+ model.eval()
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
22
+
23
+ def classify(message):
24
+ prompt = (
25
+ "You are a spam classifier.\n"
26
+ "Answer with one word: spam or ham.\n\n"
27
+ f"Message:\n{message}\n\n"
28
+ "Answer:"
29
+ )
30
+
31
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
32
+
33
+ start = time.time()
34
+ with torch.no_grad():
35
+ out = model.generate(
36
+ **inputs,
37
+ max_new_tokens=1,
38
+ do_sample=False,
39
+ temperature=0.01,
40
+ pad_token_id=tokenizer.eos_token_id,
41
+ eos_token_id=tokenizer.eos_token_id,
42
+ )
43
+
44
+ elapsed = (time.time() - start) * 1000
45
+
46
+ new_token_id = out[0, inputs["input_ids"].shape[1]]
47
+ answer = tokenizer.decode(new_token_id).strip().lower()
48
+
49
+ if answer.startswith("spam"):
50
+ label = "SPAM"
51
+ elif answer.startswith("ham"):
52
+ label = "HAM"
53
+ else:
54
+ label = "UNKNOWN"
55
+
56
+ return f"{label} ({elapsed:.1f} ms)"
57
+
58
+ iface = gr.Interface(
59
+ fn=classify,
60
+ inputs=gr.Textbox(lines=4, placeholder="Введите сообщение"),
61
+ outputs=gr.Textbox(label="Результат"),
62
+ title="ruSpam Qwen 0.5B",
63
+ description="Классификация сообщений: SPAM / HAM",
64
+ )
65
+
66
+ if __name__ == "__main__":
67
+ iface.launch()