Benedikt Veith commited on
Commit
62343fe
·
1 Parent(s): a360504
Files changed (2) hide show
  1. app.py +31 -15
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,25 +1,41 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
  import os
 
4
 
5
- # Token aus den Space-Secrets laden
6
- token = os.getenv("HF_TOKEN")
7
 
8
- # Modelle laden
9
- model_a = pipeline("text-generation", model="username/privat-modell-1", token=token)
10
- model_b = pipeline("text-generation", model="username/privat-modell-2", token=token)
11
 
12
- def predict(text, model_choice):
13
- if model_choice == "Modell 1":
14
- return model_a(text)[0]['generated_text']
15
- else:
16
- return model_b(text)[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # Gradio Interface
19
  demo = gr.Interface(
20
  fn=predict,
21
- inputs=[gr.Textbox(), gr.Dropdown(["Modell 1", "Modell 2"])],
22
- outputs="text"
 
 
 
23
  )
24
 
25
- demo.launch()
 
 
1
  import gradio as gr
2
+ import torch
3
  import os
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
 
6
+ MODEL_ID = "patronus-protect/wolf-guard"
7
+ TOKEN = os.getenv("HF_TOKEN")
8
 
9
+ # Modell und Tokenizer beim Start der App laden
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=TOKEN)
11
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, token=TOKEN)
12
 
13
+ def predict(text):
14
+ label_names = {0: "benign", 1: "attack"}
15
+
16
+ # Text verarbeiten
17
+ enc = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
18
+
19
+ with torch.no_grad():
20
+ logits = model(**enc).logits
21
+ probs = torch.softmax(logits, dim=-1).squeeze()
22
+
23
+ pred_id = int(probs.argmax())
24
+ label = label_names.get(pred_id, f"ID {pred_id}")
25
+
26
+ # Rückgabe als Dictionary für das Gradio "Label"-Feld
27
+ # Zeigt die Wahrscheinlichkeiten beider Klassen an
28
+ return {label_names[i]: float(probs[i]) for i in range(len(label_names))}
29
 
30
+ # Gradio Interface Setup
31
  demo = gr.Interface(
32
  fn=predict,
33
+ inputs=gr.Textbox(label="Eingabe Text", placeholder="Text hier einfügen...", lines=3),
34
+ outputs=gr.Label(label="Klassifizierung"),
35
+ title="Sicherheits-Check",
36
+ description="Dieses Modell erkennt Angriffe in Texten.",
37
+ allow_flagging="never"
38
  )
39
 
40
+ if __name__ == "__main__":
41
+ demo.launch()
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- transformers
 
 
1
+ transformers
2
+ torch