File size: 3,560 Bytes
6903fe1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import List

from fastapi import FastAPI, File, Form, UploadFile

from src.engines.visual_engine import PromptInjectionEngine, THREAT_DICTIONARY

# FastAPI application serving Engine D: OCR on uploaded images followed by
# prompt-injection scoring of the extracted text.
app = FastAPI(title="Engine D (Prompt Injection) API")
# Shared engine instance for this process; populated by load_engine(), either by
# the startup hook or lazily by the first request that finds it still None.
_ENGINE: PromptInjectionEngine | None = None


@app.on_event("startup")
def load_engine() -> None:
    """Build the module-wide PromptInjectionEngine if it does not exist yet.

    Runs as a FastAPI startup hook and is also called from the request handlers
    when ``_ENGINE`` is still ``None``. Once the engine exists, calling this
    again does nothing.
    """
    # NOTE(review): newer FastAPI releases deprecate ``on_event`` in favour of
    # lifespan handlers — consider migrating when the app construction changes.
    global _ENGINE
    if _ENGINE is not None:
        return
    _ENGINE = PromptInjectionEngine(use_onnx=True)


@app.get("/")
def health_check() -> dict:
    """Return a static status payload identifying this service as engine "d".

    The payload is constant; it does not inspect whether the engine is loaded.
    """
    status_payload = {"status": "ok", "engine": "d"}
    return status_payload


@app.post("/analyze_d")
async def analyze_engine_d(
    image: UploadFile = File(...),
    deep: bool = Form(True),
) -> dict:
    """OCR one uploaded image and score the extracted text for prompt injection.

    Args:
        image: Uploaded file whose raw bytes are handed to ``engine.extract_text``.
        deep: If true, delegate the verdict to ``engine.detect_injection_from_text``;
            otherwise use the phrase-match fast path only.

    Returns:
        ``{"ocr": <extract_text payload plus "ocr_confidence">, "injection": <verdict>}``.
    """
    if _ENGINE is None:
        load_engine()
    engine = _ENGINE

    # NOTE(review): extract_text (and the deep detector) are synchronous calls made
    # inside an async handler, so they presumably block the event loop while they
    # run — consider offloading once the engine's thread-safety is confirmed.
    ocr_payload = engine.extract_text(await image.read())
    text = ocr_payload["normalized_text"]
    # Plain substring containment against the threat dictionary.
    hits = [phrase for phrase in THREAT_DICTIONARY if phrase in text]

    confidences = [conf for _, conf in ocr_payload.get("scored", [])]
    if confidences:
        mean_confidence = float(sum(confidences) / len(confidences))
    else:
        # No scored OCR segments: fall back to 0.5.
        mean_confidence = 0.5

    if not deep:
        # Fast path. The reason label says "Regex", but the match above is a
        # plain substring test.
        verdict = {
            "is_threat": bool(hits),
            "risk_score": 0.9 if hits else 0.1,
            "reason": "FastPathRegex",
        }
    else:
        verdict = engine.detect_injection_from_text(text, matched_phrases=hits)

    return {
        "ocr": {**ocr_payload, "ocr_confidence": round(mean_confidence, 3)},
        "injection": verdict,
    }


def _top_prediction(classification: object) -> dict:
    """Return the first (top-ranked) ``{"label", "score"}`` dict of one classifier output.

    Assumes ``engine._get_injection_classifier()`` returns a Hugging Face
    text-classification pipeline (TODO confirm). Such pipelines return a *list*
    of prediction dicts per input whenever ``top_k`` is passed explicitly, even
    ``top_k=1``, and a bare dict only in their legacy mode. Both shapes are
    accepted here.

    Raises:
        TypeError: if the output is neither a dict nor a list/tuple.
    """
    if isinstance(classification, dict):
        return classification
    if isinstance(classification, (list, tuple)):
        # An empty list falls back to {} so the caller's defaults ("" / 0.0) apply.
        return classification[0] if classification else {}
    raise TypeError(f"Unexpected classifier output type: {type(classification).__name__}")


@app.post("/analyze_d_batch")
async def analyze_engine_d_batch(
    images: List[UploadFile] = File(...),
    deep: bool = Form(True),
) -> dict:
    """OCR a batch of uploaded images and score each one for prompt injection.

    Args:
        images: Uploaded image files, processed in upload order.
        deep: If true, run the injection classifier over the whole batch in one
            call; otherwise use the phrase-match fast path only.

    Returns:
        ``{"ocr": [per-image OCR payloads], "injection": [per-image verdicts]}``,
        both lists aligned with ``images``.

    Raises:
        TypeError: if the classifier returns an unexpected output shape.
        ValueError: if the classifier returns a different number of outputs than inputs.
    """
    if _ENGINE is None:
        load_engine()
    engine = _ENGINE
    normalized_batch: List[str] = []
    ocr_payloads: List[dict] = []
    matched_batch: List[List[str]] = []

    for img in images:
        image_bytes = await img.read()
        payload = engine.extract_text(image_bytes)
        scores = [score for _, score in payload.get("scored", [])]
        # Mean OCR score, or 0.5 when the OCR produced no scored segments.
        payload["ocr_confidence"] = round(float(sum(scores) / len(scores)) if scores else 0.5, 3)
        ocr_payloads.append(payload)
        normalized_text = payload["normalized_text"]
        normalized_batch.append(normalized_text)
        matched_batch.append([phrase for phrase in THREAT_DICTIONARY if phrase in normalized_text])

    results: List[dict]
    if deep:
        results = []
        # Batch the DeBERTa pipeline to utilize parallelism; skip the call
        # entirely for an empty batch.
        # NOTE(review): this bypasses detect_injection_from_text() (used by
        # /analyze_d) and the private _get_injection_classifier() API, so the two
        # endpoints' verdicts may drift — confirm they stay in sync.
        classifications = (
            engine._get_injection_classifier()(normalized_batch, top_k=1) if normalized_batch else []
        )
        # strict=True: an input/output count mismatch is a bug, not something to truncate.
        for classification, matched in zip(classifications, matched_batch, strict=True):
            prediction = _top_prediction(classification)
            label = str(prediction.get("label", "")).upper()
            score = float(prediction.get("score", 0.0))
            # Labels such as "LABEL_1" or "INJECTION" denote the injection class.
            is_injection = "1" in label or "INJECTION" in label
            # Treats the model as binary: risk is the injection-class score, or
            # its complement when the other class was predicted.
            risk_score = score if is_injection else 1.0 - score
            reason = f"Model={label or 'UNKNOWN'}; model_score={score:.3f}"
            if matched:
                reason += f"; matched_phrases={', '.join(sorted(set(matched)))}"
            results.append(
                {"is_threat": is_injection, "risk_score": round(risk_score, 3), "reason": reason}
            )
    else:
        # Fast path: the reason label says "Regex", but matching is plain substring containment.
        results = [
            {
                "is_threat": bool(matched),
                "risk_score": 0.9 if matched else 0.1,
                "reason": "FastPathRegex",
            }
            for matched in matched_batch
        ]

    return {"ocr": ocr_payloads, "injection": results}