darkQibit commited on
Commit
d01ca98
·
verified ·
1 Parent(s): b5cd69e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -1,9 +1,11 @@
 
1
  import time
2
  import torch
3
  import gradio as gr
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from peft import PeftModel
6
 
 
7
  MODEL_ID = "ruSpamModels/ruSpam-Qwen-0.5B-50k"
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -13,12 +15,20 @@ base_model = AutoModelForCausalLM.from_pretrained(
13
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
14
  device_map=device,
15
  trust_remote_code=True,
 
16
  )
17
 
18
- model = PeftModel.from_pretrained(base_model, MODEL_ID)
 
 
 
 
19
  model.eval()
20
 
21
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
 
22
 
23
  def classify(message):
24
  prompt = (
@@ -42,7 +52,6 @@ def classify(message):
42
  )
43
 
44
  elapsed = (time.time() - start) * 1000
45
-
46
  new_token_id = out[0, inputs["input_ids"].shape[1]]
47
  answer = tokenizer.decode(new_token_id).strip().lower()
48
 
@@ -57,11 +66,9 @@ def classify(message):
57
 
58
  iface = gr.Interface(
59
  fn=classify,
60
- inputs=gr.Textbox(lines=4, placeholder="Введите сообщение"),
61
- outputs=gr.Textbox(label="Результат"),
62
  title="ruSpam Qwen 0.5B",
63
- description="Классификация сообщений: SPAM / HAM",
64
  )
65
 
66
- if __name__ == "__main__":
67
- iface.launch()
 
1
+ import os
2
  import time
3
  import torch
4
  import gradio as gr
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  from peft import PeftModel
7
 
8
# Hugging Face access token, read from the environment (None when unset).
# Required in case the model repository is gated or private.
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub repository that holds the fine-tuned spam-classification model/adapter.
MODEL_ID = "ruSpamModels/ruSpam-Qwen-0.5B-50k"

# Run on GPU when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
 
15
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
16
  device_map=device,
17
  trust_remote_code=True,
18
+ token=HF_TOKEN,
19
  )
20
 
21
# Attach the PEFT (LoRA) adapter weights from MODEL_ID on top of the base
# model, authenticating with the same token used for the base download.
model = PeftModel.from_pretrained(base_model, MODEL_ID, token=HF_TOKEN)

# Inference only: switch off dropout / training-mode layers.
model.eval()
27
 
28
# Load the tokenizer paired with the model repository (token needed if the
# repo is gated/private).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
32
 
33
  def classify(message):
34
  prompt = (
 
52
  )
53
 
54
  elapsed = (time.time() - start) * 1000
 
55
  new_token_id = out[0, inputs["input_ids"].shape[1]]
56
  answer = tokenizer.decode(new_token_id).strip().lower()
57
 
 
66
 
67
# Gradio UI: one free-text input box, one textbox showing the verdict.
iface = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(lines=4),
    outputs=gr.Textbox(),
    title="ruSpam Qwen 0.5B",
)

# Guard the launch so importing this module (e.g. from tests or another
# script) does not start the web server as an import-time side effect.
# Hugging Face Spaces runs app.py as __main__, so deployment behavior
# is unchanged.
if __name__ == "__main__":
    iface.launch()