AlephBeth-AI committed on
Commit
1c93979
·
verified ·
1 Parent(s): c020e5c

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +40 -48
app.py CHANGED
@@ -14,9 +14,7 @@ import torch
14
  import numpy as np
15
  import plotly.graph_objects as go
16
  import plotly.io as pio
17
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
18
-
19
- from precompute import precompute_all, is_cached
20
 
21
  # ---------------------------------------------------------------------------
22
  # Logging
@@ -64,29 +62,49 @@ CATEGORY_LABELS_FR = {
64
  }
65
 
66
  # ---------------------------------------------------------------------------
67
- # Load model for real-time analysis
 
68
  # ---------------------------------------------------------------------------
69
  MODEL_ID = "meta-llama/Llama-Prompt-Guard-2-86M"
 
70
 
71
- logger.info("Loading model %s ...", MODEL_ID)
72
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
73
- model = AutoModelForSequenceClassification.from_pretrained(
74
- MODEL_ID, output_hidden_states=True
75
- )
76
- model.eval()
77
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
- model.to(DEVICE)
79
- logger.info("Model loaded on %s", DEVICE)
 
 
 
 
 
 
 
 
80
 
81
- LABELS = ["Benign", "Malicious"]
82
 
83
  # ---------------------------------------------------------------------------
84
- # Load precomputed t-SNE data
85
  # ---------------------------------------------------------------------------
86
- logger.info("Loading precomputed embeddings & t-SNE...")
87
- cached_data = precompute_all()
88
- TSNE_COORDS = cached_data["tsne_2d"]
89
- METADATA = cached_data["metadata"]
 
 
 
 
 
 
 
 
 
 
 
90
  logger.info("Loaded %d points for visualization", len(METADATA))
91
 
92
  ALL_TEXTS = [m["text"] for m in METADATA]
@@ -108,10 +126,11 @@ for i, m in enumerate(METADATA):
108
  # Analysis function
109
  # ---------------------------------------------------------------------------
110
  def analyze_prompt(text: str):
111
- """Run Llama Prompt Guard 2 on a single prompt."""
112
  if not text or not text.strip():
113
  return {}, 0.0
114
 
 
115
  inputs = tokenizer(
116
  text, return_tensors="pt", truncation=True, max_length=512, padding=True
117
  ).to(DEVICE)
@@ -618,31 +637,4 @@ with gr.Blocks(
618
  inputs=[manual_input],
619
  outputs=[result_html, risk_md],
620
  )
621
- manual_input.submit(
622
- fn=on_manual_analyze,
623
- inputs=[manual_input],
624
- outputs=[result_html, risk_md],
625
- )
626
-
627
- # Inject Plotly click handler JS
628
- demo.load(fn=None, inputs=None, outputs=None, js=PLOTLY_CLICK_JS)
629
-
630
- # Footer
631
- gr.Markdown(
632
- """
633
- ---
634
- <div style="text-align:center; color:#64748b; font-size:0.8em;">
635
- <strong>GuardLLM</strong> — Visualiseur de sécurité des prompts<br>
636
- Modèle : <a href="https://huggingface.co/meta-llama/Llama-Prompt-Guard-2-86M">
637
- Llama Prompt Guard 2 (86M)</a> par Meta &bull;
638
- Dataset : <a href="https://huggingface.co/datasets/neuralchemy/Prompt-injection-dataset">
639
- neuralchemy/Prompt-injection-dataset</a>
640
- </div>
641
- """
642
- )
643
-
644
-
645
- logger.info("Gradio app built. Ready to launch.")
646
-
647
- if __name__ == "__main__":
648
- demo.launch()
 
14
  import numpy as np
15
  import plotly.graph_objects as go
16
  import plotly.io as pio
17
+ from pathlib import Path
 
 
18
 
19
  # ---------------------------------------------------------------------------
20
  # Logging
 
62
  }
63
 
64
  # ---------------------------------------------------------------------------
65
+ # Lazy-loaded risk classifier (Llama Prompt Guard 2)
66
+ # Loaded on first use so Space startup is fast.
67
  # ---------------------------------------------------------------------------
68
  MODEL_ID = "meta-llama/Llama-Prompt-Guard-2-86M"
69
+ LABELS = ["Benign", "Malicious"]
70
 
71
+ _classifier = {"tokenizer": None, "model": None, "device": None}
72
+
73
+
74
+ def get_classifier():
75
+ if _classifier["model"] is None:
76
+ logger.info("Lazy-loading Llama Prompt Guard 2...")
77
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
78
+ tok = AutoTokenizer.from_pretrained(MODEL_ID)
79
+ mdl = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
80
+ mdl.eval()
81
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
+ mdl.to(dev)
83
+ _classifier["tokenizer"] = tok
84
+ _classifier["model"] = mdl
85
+ _classifier["device"] = dev
86
+ logger.info("Classifier loaded on %s", dev)
87
+ return _classifier["tokenizer"], _classifier["model"], _classifier["device"]
88
 
 
89
 
90
  # ---------------------------------------------------------------------------
91
+ # Load precomputed t-SNE data (fail fast if cache missing)
92
  # ---------------------------------------------------------------------------
93
+ CACHE_DIR = Path(__file__).parent / "cache"
94
+ CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"
95
+ META_FILE = CACHE_DIR / "metadata.json"
96
+
97
+ logger.info("Loading precomputed t-SNE cache from %s", CACHE_DIR)
98
+ if not CACHE_FILE.exists() or not META_FILE.exists():
99
+ raise RuntimeError(
100
+ f"Cache files not found in {CACHE_DIR}. "
101
+ f"Run `python precompute.py` first (or ensure cache/ is committed to the Space)."
102
+ )
103
+
104
+ _npz = np.load(CACHE_FILE)
105
+ TSNE_COORDS = _npz["tsne_2d"]
106
+ with open(META_FILE, "r", encoding="utf-8") as f:
107
+ METADATA = json.load(f)
108
  logger.info("Loaded %d points for visualization", len(METADATA))
109
 
110
  ALL_TEXTS = [m["text"] for m in METADATA]
 
126
  # Analysis function
127
  # ---------------------------------------------------------------------------
128
  def analyze_prompt(text: str):
129
+ """Run Llama Prompt Guard 2 on a single prompt (lazy-loads on first call)."""
130
  if not text or not text.strip():
131
  return {}, 0.0
132
 
133
+ tokenizer, model, DEVICE = get_classifier()
134
  inputs = tokenizer(
135
  text, return_tensors="pt", truncation=True, max_length=512, padding=True
136
  ).to(DEVICE)
 
637
  inputs=[manual_input],
638
  outputs=[result_html, risk_md],
639
  )
640
+ manua