jaivsh commited on
Commit
ee1f3df
·
0 Parent(s):

add Image_prompt_detection project

Browse files
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sentinel_env/
2
+ venv/
3
+ .env/
4
+
5
+ __pycache__/
6
+ *.pyc
7
+ *.pyo
8
+
9
+ .paddlex/
10
+ .cache/
11
+
12
+ .DS_Store
13
+ .vscode/
README.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multimodal Visual Security Engine (EasyOCR + ONNX DeBERTa + CLIP + BLIP)
2
+
3
+ ## System Architecture
4
+
5
+ ```mermaid
6
+ graph TD
7
+ Input[Input: Image/Video Frame] --> Split{Parallel Process}
8
+
9
+ %% Engine D Logic
10
+ Split --> EngineD[Engine D: Prompt Injection]
11
+ EngineD --> OCR[EasyOCR: Extract Text]
12
+ OCR --> Norm[Normalization Layer]
13
+ Norm --> InjectModel["DeBERTa Prompt Injection (ONNX)"]
14
+ InjectModel --> ThreatCheck{"Threat Dictionary (aux)"}
15
+ ThreatCheck --> RiskScore[Risk Score + Reason]
16
+
17
+ %% Engine E Logic
18
+ Split --> EngineE[Engine E: Cross-Modal]
19
+ EngineE --> BLIP[BLIP: Image Caption]
20
+ InputAudio[Input: Audio Transcript] --> CLIP_Text[CLIP Text Encoder]
21
+ EngineE --> CLIP_Img[CLIP Image Encoder]
22
+ CLIP_Text --> Cosine[Cosine Similarity Calc]
23
+ CLIP_Img --> Cosine
24
+ Cosine --> Threshold{"Is Score < 0.18?"}
25
+ Threshold -- Yes --> Mismatch[Status: MISMATCH - Deepfake]
26
+ Threshold -- No --> Match[Status: MATCH - Genuine]
27
+ ```
28
+
29
+ **Engine D (Visual Prompt Injection)**
30
+ OCR-based text extraction + ML classification. EasyOCR extracts visible or hidden text (with CLAHE + Otsu binarization for low-contrast regions), a normalization layer de-obfuscates tokens, and a DeBERTa prompt‑injection classifier (ONNX runtime) scores risk. A small threat dictionary is used as auxiliary evidence in the reason string, not as the primary detector.
31
+
32
+ **Engine E (Cross-Modal Consistency)**
33
+ Semantic-based (not OCR). CLIP (ViT-B/32) embeds both the video frame and the audio transcript into a shared vector space to verify that the visual context matches the spoken context. BLIP generates an image caption and we compare it with OCR text to detect prompt/scene misalignment.
34
+
35
+ ## Quick Start
36
+
37
+ ```bash
38
+ # Install dependencies
39
+ pip install -r requirements.txt
40
+
41
+ # Run the Visual Engine Test
42
+ python -m src.engines.visual_engine
43
+ ```
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Streamlit demo UI for the Visual Security Engine.

Calls either the single gateway API or the split Engine D / Engine E APIs
and renders OCR, injection, cross-modal and fusion results.
"""
import httpx
import streamlit as st


st.set_page_config(page_title="Visual Security Engine", layout="wide")
st.title("Visual Security Engine Demo")

uploaded = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "webp"])
transcript = st.text_area("Audio transcript (optional)", value="a cat sitting on a ledge")

with st.sidebar:
    st.header("API Settings")
    mode = st.selectbox("API mode", ["gateway", "split"], index=0)
    gateway_url = st.text_input("Gateway URL", value="http://localhost:8000")
    engine_d_url = st.text_input("Engine D URL", value="http://localhost:8001")
    engine_e_url = st.text_input("Engine E URL", value="http://localhost:8002")
    st.caption("Gateway mode calls a single API. Split mode calls D/E separately.")
    st.header("Performance")
    run_ocr = st.checkbox("Show OCR output", value=True)
    run_injection = st.checkbox("Run prompt-injection model", value=True)
    run_cross_modal = st.checkbox("Run cross-modal check", value=True)
    run_caption = st.checkbox("Run BLIP caption alignment", value=True)
    if run_injection and not run_ocr:
        # The injection classifier consumes OCR text, so OCR must be on.
        st.info("OCR is required for prompt-injection. Enabling OCR display.")
        run_ocr = True

run_clicked = st.button("Run analysis", type="primary")

if run_clicked and not uploaded:
    st.warning("Please upload an image to continue.")

if run_clicked and uploaded:
    image_bytes = uploaded.read()
    st.image(image_bytes, caption="Uploaded image", use_container_width=True)

    with st.spinner("Calling APIs for analysis..."):
        # BUGFIX: initialize *every* result up front so the rendering section
        # below can never hit a NameError — previously ocr_vs_image /
        # caption_alignment / final_score were undefined when the Engine E
        # call failed in split mode.
        text_payload = {}
        injection_result = {"skipped": True}
        cross_modal_result = {"skipped": True}
        ocr_vs_image = {"skipped": True}
        caption_alignment = {"skipped": True}
        final_score = None

        if mode == "gateway":
            try:
                response = httpx.post(
                    f"{gateway_url.rstrip('/')}/analyze",
                    files={"image": (uploaded.name, image_bytes, uploaded.type or "image/jpeg")},
                    data={
                        "audio_transcript": transcript,
                        "run_caption": str(run_caption).lower(),
                        "deep": str(run_injection).lower(),
                    },
                    timeout=300,
                )
                response.raise_for_status()
            except Exception as exc:
                st.error("Gateway API call failed. Is it running on the configured URL?")
                st.exception(exc)
                st.stop()
            payload = response.json()
            text_payload = payload.get("ocr", {})
            injection_result = payload.get("injection", {})
            cross_modal_result = payload.get("cross_modal", {})
            ocr_vs_image = payload.get("ocr_vs_image", {})
            caption_alignment = payload.get("caption_alignment", {})
            final_score = payload.get("final_score")
        else:
            if run_injection or run_ocr:
                try:
                    response_d = httpx.post(
                        f"{engine_d_url.rstrip('/')}/analyze_d",
                        files={"image": (uploaded.name, image_bytes, uploaded.type or "image/jpeg")},
                        data={"deep": str(run_injection).lower()},
                        timeout=300,
                    )
                    response_d.raise_for_status()
                except Exception as exc:
                    st.error("Engine D API call failed. Is it running on the configured URL?")
                    st.exception(exc)
                else:
                    payload_d = response_d.json()
                    text_payload = payload_d.get("ocr", {})
                    injection_result = payload_d.get("injection", {})

            if run_cross_modal:
                try:
                    response_e = httpx.post(
                        f"{engine_e_url.rstrip('/')}/analyze_e",
                        files={"image": (uploaded.name, image_bytes, uploaded.type or "image/jpeg")},
                        data={
                            "audio_transcript": transcript,
                            # Forward Engine D's normalized text (empty if D failed).
                            "ocr_text": text_payload.get("normalized_text", ""),
                            "run_caption": str(run_caption).lower(),
                        },
                        timeout=300,
                    )
                    response_e.raise_for_status()
                except Exception as exc:
                    st.error("Engine E API call failed. Is it running on the configured URL?")
                    st.exception(exc)
                else:
                    payload_e = response_e.json()
                    cross_modal_result = payload_e.get("cross_modal", {})
                    ocr_vs_image = payload_e.get("ocr_vs_image", {})
                    caption_alignment = payload_e.get("caption_alignment", {})

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("OCR Output")
        if not run_ocr:
            st.info("OCR display disabled.")
        else:
            st.text_area("Raw text", value=text_payload.get("raw_text", ""), height=150)
            st.text_area(
                "Normalized text", value=text_payload.get("normalized_text", ""), height=120
            )

    with col2:
        st.subheader("Engine D: Prompt Injection")
        st.json(injection_result)
        st.subheader("Engine E: Cross-Modal Consistency")
        st.json(cross_modal_result)
        st.subheader("OCR vs Image (CLIP)")
        st.json(ocr_vs_image)
        st.subheader("Caption Alignment (BLIP)")
        st.json(caption_alignment)
        if final_score is not None:
            st.subheader("Final Risk Score")
            st.metric("final_score", final_score)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ easyocr
2
+ opencv-python
3
+ sentence-transformers
4
+ transformers
5
+ httpx
6
+ streamlit
7
+ optimum
8
+ onnxruntime
9
+ fastapi
10
+ uvicorn
11
+ python-multipart
src/api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """API package for Sentinel-X."""
src/api/engine_d_server.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from fastapi import FastAPI, File, Form, UploadFile
4
+
5
+ from src.engines.visual_engine import PromptInjectionEngine, THREAT_DICTIONARY
6
+
7
# FastAPI app serving Engine D (OCR + prompt-injection detection).
app = FastAPI(title="Engine D (Prompt Injection) API")
# Module-wide singleton engine; populated lazily at startup (see load_engine).
_ENGINE: PromptInjectionEngine | None = None
9
+
10
+
11
@app.on_event("startup")
def load_engine() -> None:
    """Create the module-wide PromptInjectionEngine exactly once."""
    global _ENGINE
    if _ENGINE is not None:
        return
    _ENGINE = PromptInjectionEngine(use_onnx=True)
16
+
17
+
18
@app.get("/")
def health_check() -> dict:
    """Liveness probe: confirms the Engine D service is responding."""
    return dict(status="ok", engine="d")
21
+
22
+
23
@app.post("/analyze_d")
async def analyze_engine_d(
    image: UploadFile = File(...),
    deep: bool = Form(True),
) -> dict:
    """Run OCR and prompt-injection scoring on one uploaded image.

    Args:
        image: Uploaded image file.
        deep: When true, run the DeBERTa classifier; when false, use the
            dictionary-only fast path.

    Returns:
        dict with an "ocr" payload (including mean OCR confidence) and an
        "injection" verdict.
    """
    if _ENGINE is None:
        load_engine()
    engine = _ENGINE
    payload = engine.extract_text(await image.read())
    normalized = payload["normalized_text"]
    hits = [phrase for phrase in THREAT_DICTIONARY if phrase in normalized]
    # Mean EasyOCR confidence; 0.5 ("unknown") when nothing was scored.
    confidences = [conf for _, conf in payload.get("scored", [])]
    mean_conf = float(sum(confidences) / len(confidences)) if confidences else 0.5
    if deep:
        verdict = engine.detect_injection_from_text(normalized, matched_phrases=hits)
    else:
        verdict = {
            "is_threat": bool(hits),
            "risk_score": 0.9 if hits else 0.1,
            "reason": "FastPathRegex",
        }
    return {
        "ocr": {**payload, "ocr_confidence": round(mean_conf, 3)},
        "injection": verdict,
    }
49
+
50
+
51
@app.post("/analyze_d_batch")
async def analyze_engine_d_batch(
    images: List[UploadFile] = File(...),
    deep: bool = Form(True),
) -> dict:
    """Batch variant of /analyze_d: OCR + injection scoring for many images.

    Args:
        images: Uploaded image files.
        deep: When true, classify all normalized texts in one batched
            DeBERTa call; when false, use the dictionary-only fast path.

    Returns:
        dict with parallel lists: "ocr" payloads and "injection" verdicts,
        index-aligned with the input order.
    """
    if _ENGINE is None:
        load_engine()
    engine = _ENGINE
    normalized_batch: List[str] = []
    ocr_payloads: List[dict] = []
    matched_batch: List[List[str]] = []

    for img in images:
        image_bytes = await img.read()
        payload = engine.extract_text(image_bytes)
        scores = [score for _, score in payload.get("scored", [])]
        payload["ocr_confidence"] = round(float(sum(scores) / len(scores)) if scores else 0.5, 3)
        ocr_payloads.append(payload)
        normalized_text = payload["normalized_text"]
        normalized_batch.append(normalized_text)
        matched_batch.append([phrase for phrase in THREAT_DICTIONARY if phrase in normalized_text])

    results: List[dict] = []
    # Guard the empty batch: calling the pipeline with [] is pointless and the
    # fallback loop below simply produces no results.
    if deep and normalized_batch:
        # Batch the DeBERTa pipeline to utilize parallelism.
        classifier = engine._get_injection_classifier()
        classifications = classifier(normalized_batch, top_k=1)
        for idx, classification in enumerate(classifications):
            # BUGFIX: with top_k set, the HF pipeline may return a one-element
            # list of dicts per input (the single-text path indexes [0] for
            # this reason); unwrap before calling .get to avoid AttributeError.
            if isinstance(classification, list):
                classification = classification[0]
            label = str(classification.get("label", "")).upper()
            score = float(classification.get("score", 0.0))
            # Positive class is assumed to be labelled "...1" or "INJECTION",
            # matching the single-image endpoint's convention.
            is_injection = "1" in label or "INJECTION" in label
            risk_score = score if is_injection else 1.0 - score
            reason = f"Model={label or 'UNKNOWN'}; model_score={score:.3f}"
            if matched_batch[idx]:
                reason += f"; matched_phrases={', '.join(sorted(set(matched_batch[idx])))}"
            results.append(
                {"is_threat": bool(is_injection), "risk_score": round(risk_score, 3), "reason": reason}
            )
    else:
        for matched in matched_batch:
            results.append(
                {
                    "is_threat": bool(matched),
                    "risk_score": 0.9 if matched else 0.1,
                    "reason": "FastPathRegex",
                }
            )

    return {"ocr": ocr_payloads, "injection": results}
src/api/engine_e_server.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, Form, UploadFile
2
+
3
+ from src.engines.visual_engine import CrossModalEngine
4
+
5
# FastAPI app serving Engine E (cross-modal consistency checks).
app = FastAPI(title="Engine E (Cross-Modal) API")
# Module-wide singleton engine; populated lazily at startup (see load_engine).
_ENGINE: CrossModalEngine | None = None
7
+
8
+
9
@app.on_event("startup")
def load_engine() -> None:
    """Create the module-wide CrossModalEngine exactly once."""
    global _ENGINE
    if _ENGINE is not None:
        return
    _ENGINE = CrossModalEngine()
14
+
15
+
16
@app.get("/")
def health_check() -> dict:
    """Liveness probe: confirms the Engine E service is responding."""
    return dict(status="ok", engine="e")
19
+
20
+
21
@app.post("/analyze_e")
async def analyze_engine_e(
    image: UploadFile = File(...),
    audio_transcript: str = Form(""),
    ocr_text: str = Form(""),
    run_caption: bool = Form(True),
) -> dict:
    """Run the three cross-modal consistency checks for one image.

    Args:
        image: Uploaded image file.
        audio_transcript: Spoken-context text compared against the image.
        ocr_text: Normalized OCR text (typically from Engine D).
        run_caption: Whether to run the BLIP caption-alignment check.
    """
    if _ENGINE is None:
        load_engine()
    engine = _ENGINE
    data = await image.read()

    cross_modal_result = engine.check_cross_modal(data, audio_transcript)

    if ocr_text:
        ocr_vs_image = engine.check_ocr_vs_image(data, ocr_text)
    else:
        # No OCR text to compare: report a neutral, non-mismatching result.
        ocr_vs_image = {"is_mismatch": False, "consistency_score": 0.0}

    if run_caption:
        caption_alignment = engine.check_caption_alignment(data, ocr_text)
    else:
        caption_alignment = {"caption": "", "alignment_score": 0.0}

    return {
        "cross_modal": cross_modal_result,
        "ocr_vs_image": ocr_vs_image,
        "caption_alignment": caption_alignment,
    }
src/api/gateway_server.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import httpx
4
+ from fastapi import FastAPI, File, Form, UploadFile
5
+
6
# FastAPI app fronting Engine D and Engine E behind a single /analyze endpoint.
app = FastAPI(title="Visual Security Engine Gateway API")
7
+
8
+
9
def _engine_d_url() -> str:
    """Resolve Engine D's base URL from ENGINE_D_URL (default localhost:8001)."""
    base = os.environ.get("ENGINE_D_URL", "http://localhost:8001")
    return base.rstrip("/")
11
+
12
+
13
def _engine_e_url() -> str:
    """Resolve Engine E's base URL from ENGINE_E_URL (default localhost:8002)."""
    base = os.environ.get("ENGINE_E_URL", "http://localhost:8002")
    return base.rstrip("/")
15
+
16
+
17
def _clamp(value: float) -> float:
    """Clamp *value* into the closed interval [0.0, 1.0]."""
    if value < 0.0:
        return 0.0
    if value > 1.0:
        return 1.0
    return value
19
+
20
+
21
@app.get("/")
def health_check() -> dict:
    """Liveness probe: confirms the gateway service is responding."""
    return dict(status="ok", engine="gateway")
24
+
25
+
26
@app.post("/analyze")
async def analyze(
    image: UploadFile = File(...),
    audio_transcript: str = Form(""),
    run_caption: bool = Form(True),
    deep: bool = Form(True),
) -> dict:
    """Fan one image out to Engine D, feed its OCR text to Engine E, and
    fuse all sub-scores into a single clamped final risk score.

    Args:
        image: Uploaded image file, forwarded to both engines.
        audio_transcript: Spoken-context text for the cross-modal check.
        run_caption: Whether Engine E should run BLIP caption alignment.
        deep: Whether Engine D should run the full classifier.
    """
    image_bytes = await image.read()
    upload = (image.filename, image_bytes, image.content_type or "image/jpeg")

    async with httpx.AsyncClient(timeout=300) as client:
        resp_d = await client.post(
            f"{_engine_d_url()}/analyze_d",
            files={"image": upload},
            data={"deep": str(deep).lower()},
        )
        resp_d.raise_for_status()
        payload_d = resp_d.json()

        resp_e = await client.post(
            f"{_engine_e_url()}/analyze_e",
            files={"image": upload},
            data={
                "audio_transcript": audio_transcript,
                "ocr_text": payload_d.get("ocr", {}).get("normalized_text", ""),
                "run_caption": str(run_caption).lower(),
            },
        )
        resp_e.raise_for_status()
        payload_e = resp_e.json()

    injection = payload_d.get("injection", {})
    cross_modal = payload_e.get("cross_modal", {})
    ocr_vs_image = payload_e.get("ocr_vs_image", {})
    caption_align = payload_e.get("caption_alignment", {})

    ocr_conf = float(payload_d.get("ocr", {}).get("ocr_confidence", 0.5))

    # Weighted fusion: injection risk dominates; low confidence/alignment on
    # any channel pushes the score up. Weights sum to 1.0.
    raw_score = (
        0.45 * float(injection.get("risk_score", 0.0))
        + 0.15 * (1.0 - ocr_conf)
        + 0.2 * (1.0 - float(cross_modal.get("consistency_score", 0.0)))
        + 0.1 * (1.0 - float(ocr_vs_image.get("consistency_score", 0.0)))
        + 0.1 * (1.0 - float(caption_align.get("alignment_score", 0.0)))
    )

    return {
        "ocr": payload_d.get("ocr", {}),
        "injection": injection,
        "cross_modal": cross_modal,
        "ocr_vs_image": ocr_vs_image,
        "caption_alignment": caption_align,
        "final_score": round(_clamp(raw_score), 3),
    }
src/api/server.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, Form, UploadFile
2
+
3
+ from src.engines.visual_engine import VisualSecurityEngine
4
+
5
# Monolithic FastAPI app running both engines in one process.
app = FastAPI(title="Visual Security Engine API")
# Module-wide singleton engine; populated lazily at startup (see load_engine).
_ENGINE: VisualSecurityEngine | None = None
7
+
8
+
9
@app.on_event("startup")
def load_engine() -> None:
    """Create the module-wide VisualSecurityEngine exactly once."""
    global _ENGINE
    if _ENGINE is not None:
        return
    _ENGINE = VisualSecurityEngine()
14
+
15
+
16
@app.get("/")
def health_check() -> dict:
    """Liveness probe: confirms the monolithic service is responding."""
    return dict(status="ok")
19
+
20
+
21
@app.post("/analyze")
async def analyze_image(
    image: UploadFile = File(...),
    audio_transcript: str = Form(""),
    run_ocr: bool = Form(True),
    run_injection: bool = Form(True),
    run_cross_modal: bool = Form(True),
) -> dict:
    """Single-process endpoint: OCR, injection detection, cross-modal check.

    Any stage toggled off reports {"skipped": True} in its result slot.
    """
    if _ENGINE is None:
        load_engine()
    engine = _ENGINE
    image_bytes = await image.read()

    # Injection detection consumes OCR text, so OCR is forced on with it.
    run_ocr = run_ocr or run_injection

    text_payload = engine.extract_text(image_bytes) if run_ocr else None

    if run_injection:
        normalized = text_payload["normalized_text"] if text_payload else ""
        injection_result = engine.detect_injection_from_text(normalized)
    else:
        injection_result = {"skipped": True}

    if not run_cross_modal:
        cross_modal_result = {"skipped": True}
    elif audio_transcript.strip():
        cross_modal_result = engine.check_cross_modal(image_bytes, audio_transcript)
    else:
        # No transcript to compare against: report maximal mismatch,
        # mirroring CrossModalEngine's empty-transcript behavior.
        cross_modal_result = {"is_mismatch": True, "consistency_score": 0.0}

    return {
        "ocr": text_payload or {"skipped": True},
        "injection": injection_result,
        "cross_modal": cross_modal_result,
    }
src/engines/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Engines package for Sentinel-X."""
src/engines/visual_engine.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import re
4
+ import urllib.request
5
+ from typing import Any, Dict, Iterable, List, Tuple, Union
6
+
7
+ import numpy as np
8
+ import cv2
9
+ from PIL import Image
10
+ import easyocr
11
+ from sentence_transformers import SentenceTransformer
12
+ import torch
13
+ from transformers import (
14
+ AutoModelForSequenceClassification,
15
+ AutoTokenizer,
16
+ BlipForConditionalGeneration,
17
+ BlipProcessor,
18
+ pipeline,
19
+ )
20
+
21
try:
    # Optional ONNX Runtime backend (via optimum); when the package is
    # missing the engine falls back to the plain PyTorch classifier.
    from optimum.onnxruntime import ORTModelForSequenceClassification

    _HAS_ORT = True
except Exception:
    _HAS_ORT = False


# Auxiliary phrase list matched against *normalized* (lowercased,
# de-obfuscated) OCR text. Used as supporting evidence in reason strings
# and as the fallback detector when the classifier cannot be loaded.
THREAT_DICTIONARY = [
    "ignore previous",
    "system override",
    "transfer funds",
    "bypass safety",
    "disable guardrails",
    "override policy",
    "reveal secrets",
]
38
+
39
+
40
class PromptInjectionEngine:
    """Engine D: OCR-based visual prompt-injection detector.

    Pipeline: EasyOCR extracts text (with a CLAHE + Otsu enhanced second
    pass for low-contrast/hidden text), a normalization step de-obfuscates
    tokens, and a DeBERTa text-classification model scores the result
    (ONNX via optimum when available, otherwise PyTorch). Dictionary
    matches are only auxiliary evidence. All heavy resources are created
    lazily on first use.
    """

    def __init__(
        self,
        use_onnx: bool | None = None,
        force_cpu: bool | None = None,
        model_name: str | None = None,
    ) -> None:
        """Configure the engine without loading any models yet.

        Args:
            use_onnx: Prefer the ONNX backend; when None, read from the
                SENTINEL_USE_ONNX env var (default on).
            force_cpu: Pin the PyTorch backend to CPU; when None, read from
                the SENTINEL_FORCE_CPU env var.
            model_name: Hugging Face model id; defaults to the protectai
                DeBERTa prompt-injection classifier.
        """
        # setdefault keeps any values the operator already exported.
        os.environ.setdefault("HF_HUB_TIMEOUT", "60")
        os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "60")
        os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
        self._ocr: easyocr.Reader | None = None
        self._injection_classifier = None
        self._model_name = model_name or "protectai/deberta-v3-base-prompt-injection"
        if force_cpu is None:
            self._force_cpu = os.environ.get("SENTINEL_FORCE_CPU", "").lower() in {
                "1",
                "true",
                "yes",
            }
        else:
            self._force_cpu = force_cpu
        if use_onnx is None:
            self._use_onnx = os.environ.get("SENTINEL_USE_ONNX", "1") not in {"0", "false"}
        else:
            self._use_onnx = use_onnx

    def _get_ocr(self) -> easyocr.Reader:
        """Lazily build the EasyOCR reader; fall back to CPU if GPU init fails."""
        if self._ocr is None:
            ocr_gpu = os.environ.get("SENTINEL_OCR_GPU", "1") not in {"0", "false"}
            try:
                self._ocr = easyocr.Reader(["en"], gpu=ocr_gpu)
            except Exception:
                self._ocr = easyocr.Reader(["en"], gpu=False)
        return self._ocr

    def _get_injection_classifier(self):
        """Lazily build the text-classification pipeline.

        Tries local files first (offline-friendly), then falls back to a
        hub download. ONNX path is used only when both requested and the
        optimum runtime imported successfully.
        """
        if self._injection_classifier is None:
            if self._use_onnx and _HAS_ORT:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(
                        self._model_name, subfolder="onnx", local_files_only=True
                    )
                    model = ORTModelForSequenceClassification.from_pretrained(
                        self._model_name, subfolder="onnx", export=False, local_files_only=True
                    )
                except Exception:
                    # Cache miss: allow a network download of the same files.
                    tokenizer = AutoTokenizer.from_pretrained(self._model_name, subfolder="onnx")
                    model = ORTModelForSequenceClassification.from_pretrained(
                        self._model_name, subfolder="onnx", export=False
                    )
                self._injection_classifier = pipeline(
                    "text-classification",
                    model=model,
                    tokenizer=tokenizer,
                    truncation=True,
                    max_length=512,
                )
            else:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(
                        self._model_name, local_files_only=True
                    )
                    model = AutoModelForSequenceClassification.from_pretrained(
                        self._model_name, local_files_only=True
                    )
                except Exception:
                    tokenizer = AutoTokenizer.from_pretrained(self._model_name)
                    model = AutoModelForSequenceClassification.from_pretrained(self._model_name)
                # Apple-Silicon "mps" is the only GPU path considered here;
                # CUDA is not probed — presumably deliberate for this project.
                device = torch.device(
                    "cpu"
                    if self._force_cpu or not torch.backends.mps.is_available()
                    else "mps"
                )
                self._injection_classifier = pipeline(
                    "text-classification",
                    model=model,
                    tokenizer=tokenizer,
                    truncation=True,
                    max_length=512,
                    device=device,
                )
        return self._injection_classifier

    @staticmethod
    def _normalize_text(text: str) -> str:
        """Lowercase, strip non-alphanumerics, and merge spaced-out letters.

        Runs of single-letter tokens (e.g. "i g n o r e") are fused into one
        word to defeat simple character-spacing obfuscation.
        """
        lowered = text.lower()
        cleaned = re.sub(r"[^a-z0-9]+", " ", lowered)
        tokens = cleaned.split()

        def merge_single_letter_runs(items: Iterable[str]) -> List[str]:
            merged: List[str] = []
            run: List[str] = []
            for token in items:
                if len(token) == 1:
                    run.append(token)
                    continue
                if run:
                    merged.append("".join(run))
                    run = []
                merged.append(token)
            if run:
                merged.append("".join(run))
            return merged

        merged_tokens = merge_single_letter_runs(tokens)
        return " ".join(merged_tokens)

    @staticmethod
    def _load_image_for_ocr(image: Union[str, bytes]) -> Union[str, np.ndarray]:
        """Pass file paths through untouched; decode raw bytes to an RGB array."""
        if isinstance(image, str):
            return image
        pil_image = Image.open(io.BytesIO(image)).convert("RGB")
        return np.array(pil_image)

    @staticmethod
    def _enhance_for_hidden_text(image: np.ndarray) -> np.ndarray:
        """Boost low-contrast text: CLAHE equalization then Otsu binarization."""
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(gray)
        _, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # EasyOCR is fed 3-channel images, so convert the mask back to RGB.
        return cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)

    @staticmethod
    def _load_image_for_clip(image: Union[str, bytes]) -> Image.Image:
        """Load a path or raw bytes as an RGB PIL image."""
        if isinstance(image, str):
            return Image.open(image).convert("RGB")
        return Image.open(io.BytesIO(image)).convert("RGB")

    @staticmethod
    def _extract_ocr_text(ocr_result: List[Any]) -> Tuple[str, List[Tuple[str, float]]]:
        """Flatten EasyOCR output into joined text plus (text, confidence) pairs."""
        fragments: List[str] = []
        scored: List[Tuple[str, float]] = []
        # EasyOCR returns: [([bbox], text, confidence), ...]
        for line in ocr_result or []:
            if not line or len(line) < 2:
                continue
            text = str(line[1])
            score = float(line[2]) if len(line) > 2 and isinstance(line[2], (float, int)) else None
            if text:
                fragments.append(text)
                if score is not None:
                    scored.append((text, score))
        return " ".join(fragments), scored

    def detect_injection(self, image: Union[str, bytes]) -> Dict[str, Any]:
        """Convenience wrapper: OCR the image, then classify the text."""
        text_payload = self.extract_text(image)
        return self.detect_injection_from_text(
            text_payload["normalized_text"],
            matched_phrases=[
                phrase for phrase in THREAT_DICTIONARY if phrase in text_payload["normalized_text"]
            ],
        )

    def detect_injection_from_text(
        self, normalized_text: str, matched_phrases: List[str] | None = None
    ) -> Dict[str, Any]:
        """Score normalized text for prompt injection.

        Returns a dict with "is_threat", "risk_score" in [0, 1], and a
        human-readable "reason". Falls back to dictionary matching when the
        classifier cannot run.
        """
        if not normalized_text:
            return {
                "is_threat": False,
                "risk_score": 0.0,
                "reason": "No readable text detected in image.",
            }

        # NOTE(review): an explicitly-passed empty list falls through to a
        # recompute here ("or"); harmless given the same dictionary is used.
        matched = matched_phrases or [
            phrase for phrase in THREAT_DICTIONARY if phrase in normalized_text
        ]

        try:
            classifier = self._get_injection_classifier()
            classification = classifier(normalized_text, top_k=1)[0]
            label = str(classification.get("label", "")).upper()
            score = float(classification.get("score", 0.0))
            # Positive class assumed to be labelled "...1" or "INJECTION".
            is_injection = "1" in label or "INJECTION" in label
            risk_score = score if is_injection else 1.0 - score
            reason_parts = [
                f"Model={label or 'UNKNOWN'}",
                f"model_score={score:.3f}",
            ]
        except Exception:
            # Classifier unavailable: best-effort dictionary verdict.
            is_injection = bool(matched)
            risk_score = 0.9 if matched else 0.1
            reason_parts = ["Model=FALLBACK", "model_score=0.0"]
        if matched:
            reason_parts.append(f"matched_phrases={', '.join(sorted(set(matched)))}")

        return {
            "is_threat": bool(is_injection),
            "risk_score": round(risk_score, 3),
            "reason": "; ".join(reason_parts),
        }

    def extract_text(self, image: Union[str, bytes]) -> Dict[str, Any]:
        """OCR the image; for in-memory images, also OCR an enhanced copy.

        Returns a dict with "raw_text", "normalized_text", and "scored"
        (text, confidence) pairs. File-path inputs skip the enhancement
        pass (EasyOCR reads the path directly).
        """
        ocr_input = self._load_image_for_ocr(image)
        reader = self._get_ocr()
        if isinstance(ocr_input, str):
            ocr_result = reader.readtext(ocr_input)
            raw_text, scored = self._extract_ocr_text(ocr_result)
            normalized = self._normalize_text(raw_text)
        else:
            base_result = reader.readtext(ocr_input)
            enhanced_image = self._enhance_for_hidden_text(ocr_input)
            enhanced_result = reader.readtext(enhanced_image)
            raw_text_base, scored_base = self._extract_ocr_text(base_result)
            raw_text_enh, scored_enh = self._extract_ocr_text(enhanced_result)
            # Both passes are concatenated; duplicates are possible by design.
            raw_text = " ".join([raw_text_base, raw_text_enh]).strip()
            scored = scored_base + scored_enh
            normalized = self._normalize_text(raw_text)
        return {
            "raw_text": raw_text,
            "normalized_text": normalized,
            "scored": scored,
        }
252
+
253
+
254
class CrossModalEngine:
    """Engine E: semantic consistency checks via CLIP and BLIP.

    CLIP (sentence-transformers wrapper) embeds image and text into a
    shared space; cosine similarity below 0.18 is treated as a mismatch.
    BLIP generates an image caption for the caption-alignment check.
    The BLIP captioner is loaded lazily; CLIP loads in __init__.
    """

    def __init__(self, clip_model: str | None = None, caption_model: str | None = None) -> None:
        """Load the CLIP encoder and remember the BLIP model id.

        Args:
            clip_model: sentence-transformers model name; defaults from
                SENTINEL_CLIP_MODEL, then "clip-ViT-B-32".
            caption_model: BLIP model id; defaults from SENTINEL_BLIP_MODEL,
                then the Salesforce base captioning model.
        """
        self._clip = SentenceTransformer(
            clip_model or os.environ.get("SENTINEL_CLIP_MODEL", "clip-ViT-B-32")
        )
        self._captioner = None
        self._caption_model = caption_model or os.environ.get(
            "SENTINEL_BLIP_MODEL", "Salesforce/blip-image-captioning-base"
        )

    @staticmethod
    def _load_image_for_clip(image: Union[str, bytes]) -> Image.Image:
        """Load a path or raw bytes as an RGB PIL image."""
        if isinstance(image, str):
            return Image.open(image).convert("RGB")
        return Image.open(io.BytesIO(image)).convert("RGB")

    def _get_captioner(self):
        """Lazily build and cache the (processor, model, device) BLIP triple."""
        if self._captioner is None:
            # Use BLIP processor + model directly to avoid pipeline task mismatches.
            processor = BlipProcessor.from_pretrained(self._caption_model)
            model = BlipForConditionalGeneration.from_pretrained(self._caption_model)
            device = os.environ.get("SENTINEL_BLIP_DEVICE", "cpu")
            model.to(device)
            self._captioner = (processor, model, device)
        return self._captioner

    def check_cross_modal(self, image: Union[str, bytes], audio_transcript: str) -> Dict[str, Any]:
        """CLIP-compare the image against the audio transcript.

        An empty transcript is reported as a mismatch with score 0.0.
        Embeddings are L2-normalized, so the dot product is the cosine
        similarity; 0.18 is the project's mismatch threshold.
        """
        if not audio_transcript:
            return {"is_mismatch": True, "consistency_score": 0.0}

        pil_image = self._load_image_for_clip(image)
        image_emb = self._clip.encode([pil_image], normalize_embeddings=True)
        text_emb = self._clip.encode([audio_transcript], normalize_embeddings=True)
        similarity = float(np.dot(image_emb[0], text_emb[0]))

        return {
            "is_mismatch": similarity < 0.18,
            "consistency_score": round(similarity, 4),
        }

    def check_ocr_vs_image(self, image: Union[str, bytes], ocr_text: str) -> Dict[str, Any]:
        """CLIP-compare the image against its own OCR text.

        Empty OCR text returns a neutral non-mismatch (nothing to compare).
        """
        if not ocr_text:
            return {"is_mismatch": False, "consistency_score": 0.0}
        pil_image = self._load_image_for_clip(image)
        image_emb = self._clip.encode([pil_image], normalize_embeddings=True)
        text_emb = self._clip.encode([ocr_text], normalize_embeddings=True)
        similarity = float(np.dot(image_emb[0], text_emb[0]))
        return {
            "is_mismatch": similarity < 0.18,
            "consistency_score": round(similarity, 4),
        }

    def check_caption_alignment(self, image: Union[str, bytes], ocr_text: str) -> Dict[str, Any]:
        """Generate a BLIP caption and CLIP-compare it against the OCR text.

        Returns the caption plus an alignment score; empty OCR text skips
        generation entirely.
        """
        if not ocr_text:
            return {"caption": "", "alignment_score": 0.0}
        pil_image = self._load_image_for_clip(image)
        processor, model, device = self._get_captioner()
        inputs = processor(images=pil_image, return_tensors="pt").to(device)
        output_ids = model.generate(**inputs, max_new_tokens=30)
        caption = processor.decode(output_ids[0], skip_special_tokens=True)
        text_emb = self._clip.encode([ocr_text], normalize_embeddings=True)
        caption_emb = self._clip.encode([caption], normalize_embeddings=True)
        similarity = float(np.dot(text_emb[0], caption_emb[0]))
        return {"caption": caption, "alignment_score": round(similarity, 4)}
318
+
319
+
320
class VisualSecurityEngine:
    """Facade bundling Engine D (prompt injection) and Engine E (cross-modal).

    Exposes the union of both engines' public methods so callers only need
    a single object.
    """

    def __init__(
        self,
        use_onnx: bool | None = None,
        force_cpu: bool | None = None,
        clip_model: str | None = None,
    ) -> None:
        self.engine_d = PromptInjectionEngine(use_onnx=use_onnx, force_cpu=force_cpu)
        self.engine_e = CrossModalEngine(clip_model=clip_model)

    # ---- Engine D delegations ------------------------------------------

    def extract_text(self, image: Union[str, bytes]) -> Dict[str, Any]:
        """OCR the image; see PromptInjectionEngine.extract_text."""
        return self.engine_d.extract_text(image)

    def detect_injection(self, image: Union[str, bytes]) -> Dict[str, Any]:
        """OCR then classify; see PromptInjectionEngine.detect_injection."""
        return self.engine_d.detect_injection(image)

    def detect_injection_from_text(
        self, normalized_text: str, matched_phrases: List[str] | None = None
    ) -> Dict[str, Any]:
        """Classify already-normalized text for prompt injection."""
        return self.engine_d.detect_injection_from_text(normalized_text, matched_phrases)

    # ---- Engine E delegations ------------------------------------------

    def check_cross_modal(self, image: Union[str, bytes], audio_transcript: str) -> Dict[str, Any]:
        """CLIP similarity between the image and the audio transcript."""
        return self.engine_e.check_cross_modal(image, audio_transcript)

    def check_ocr_vs_image(self, image: Union[str, bytes], ocr_text: str) -> Dict[str, Any]:
        """CLIP similarity between the image and its OCR text."""
        return self.engine_e.check_ocr_vs_image(image, ocr_text)

    def check_caption_alignment(self, image: Union[str, bytes], ocr_text: str) -> Dict[str, Any]:
        """BLIP caption vs OCR text alignment."""
        return self.engine_e.check_caption_alignment(image, ocr_text)
349
+
350
+
351
def _download_demo_image() -> bytes:
    """Fetch a small public test photo, trying mirror URLs in order.

    Returns:
        The raw image bytes of the first URL that responds.

    Raises:
        RuntimeError: if every candidate URL fails.
    """
    candidates = (
        "https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg",
        "https://upload.wikimedia.org/wikipedia/commons/7/74/A-Cat.jpg",
    )
    headers = {"User-Agent": "Mozilla/5.0 (Sentinel-X demo)"}
    failure: Exception | None = None
    for candidate in candidates:
        request = urllib.request.Request(candidate, headers=headers)
        try:
            with urllib.request.urlopen(request, timeout=20) as response:
                return response.read()
        except Exception as exc:  # pragma: no cover - best effort demo download
            failure = exc
    raise RuntimeError(f"Failed to download demo image: {failure}")
367
+
368
+
369
if __name__ == "__main__":
    # Manual smoke test: download a demo photo and run both engines on it.
    demo_bytes = _download_demo_image()

    engine = VisualSecurityEngine()
    injection_result = engine.detect_injection(demo_bytes)
    cross_modal_result = engine.check_cross_modal(demo_bytes, "a cat sitting on a ledge")

    print("Injection detection:", injection_result)
    print("Cross-modal consistency:", cross_modal_result)
test_visual.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
# Placeholder harness; replace with real assertions against the engines.
if __name__ == "__main__":
    print("Run visual engine tests once visual_engine.py is ready.")