GitHub Actions commited on
Commit ·
5d50b8b
1
Parent(s): 8396c67
Deploy backend from GitHub 522e1ff559eaf4f3a628b450c12e01b910565458
Browse files- backend/app/api/routes.py +22 -43
- backend/app/core/config.py +4 -4
- backend/app/db/firestore.py +134 -45
- backend/app/services/groq_service.py +29 -53
- backend/app/services/hf_service.py +73 -52
- backend/tests/test_api.py +13 -38
- backend/tests/test_services.py +29 -18
backend/app/api/routes.py
CHANGED
|
@@ -1,31 +1,32 @@
|
|
| 1 |
"""
|
| 2 |
Main API routes for the LLM Misuse Detection system.
|
| 3 |
Endpoints: /api/analyze, /api/analyze/bulk, /api/results/{id}
|
| 4 |
-
Persistence: Firestore
|
| 5 |
"""
|
| 6 |
import hashlib
|
|
|
|
| 7 |
import time
|
| 8 |
from datetime import datetime, timezone
|
| 9 |
|
| 10 |
-
from fastapi import APIRouter,
|
| 11 |
|
| 12 |
from backend.app.api.models import (
|
| 13 |
-
AnalyzeRequest,
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
| 15 |
)
|
| 16 |
-
from backend.app.core.auth import get_current_user
|
| 17 |
from backend.app.core.config import settings
|
|
|
|
| 18 |
from backend.app.core.redis import check_rate_limit, get_cached, set_cached
|
| 19 |
-
from backend.app.db.firestore import
|
| 20 |
from backend.app.models.schemas import AnalysisResult
|
| 21 |
from backend.app.services.ensemble import compute_ensemble
|
| 22 |
-
from backend.app.services.hf_service import detect_ai_text, get_embeddings, detect_harm
|
| 23 |
from backend.app.services.groq_service import compute_perplexity
|
|
|
|
| 24 |
from backend.app.services.stylometry import compute_stylometry_score
|
| 25 |
from backend.app.services.vector_db import compute_cluster_score, upsert_embedding
|
| 26 |
-
from backend.app.core.logging import get_logger
|
| 27 |
-
|
| 28 |
-
import json
|
| 29 |
|
| 30 |
logger = get_logger(__name__)
|
| 31 |
router = APIRouter(prefix="/api", tags=["analysis"])
|
|
@@ -33,28 +34,24 @@ router = APIRouter(prefix="/api", tags=["analysis"])
|
|
| 33 |
COLLECTION = "analysis_results"
|
| 34 |
|
| 35 |
|
| 36 |
-
async def _analyze_text(text: str, user_id: str = None) -> dict:
|
| 37 |
"""Core analysis pipeline for a single text."""
|
| 38 |
start_time = time.time()
|
| 39 |
text_hash = hashlib.sha256(text.encode()).hexdigest()
|
| 40 |
|
| 41 |
-
# Check cache
|
| 42 |
cached = await get_cached(f"analysis:{text_hash}")
|
| 43 |
if cached:
|
| 44 |
return json.loads(cached)
|
| 45 |
|
| 46 |
-
# Step 1: AI detection
|
| 47 |
try:
|
| 48 |
p_ai = await detect_ai_text(text)
|
| 49 |
except Exception:
|
| 50 |
p_ai = None
|
| 51 |
|
| 52 |
-
# Step 2: Perplexity (cost-gated)
|
| 53 |
s_perp = None
|
| 54 |
if p_ai is not None and p_ai > settings.PERPLEXITY_THRESHOLD:
|
| 55 |
s_perp = await compute_perplexity(text)
|
| 56 |
|
| 57 |
-
# Step 3: Embeddings + cluster score
|
| 58 |
s_embed_cluster = None
|
| 59 |
try:
|
| 60 |
embeddings = await get_embeddings(text)
|
|
@@ -63,16 +60,10 @@ async def _analyze_text(text: str, user_id: str = None) -> dict:
|
|
| 63 |
except Exception:
|
| 64 |
pass
|
| 65 |
|
| 66 |
-
# Step 4: Harm/extremism
|
| 67 |
p_ext = await detect_harm(text)
|
| 68 |
-
|
| 69 |
-
# Step 5: Stylometry
|
| 70 |
s_styl = compute_stylometry_score(text)
|
| 71 |
-
|
| 72 |
-
# Step 6: Watermark placeholder
|
| 73 |
p_watermark = None
|
| 74 |
|
| 75 |
-
# Step 7: Ensemble
|
| 76 |
ensemble_result = compute_ensemble(
|
| 77 |
p_ai=p_ai,
|
| 78 |
s_perp=s_perp,
|
|
@@ -97,15 +88,12 @@ async def _analyze_text(text: str, user_id: str = None) -> dict:
|
|
| 97 |
"processing_time_ms": processing_time_ms,
|
| 98 |
}
|
| 99 |
|
| 100 |
-
# Cache
|
| 101 |
try:
|
| 102 |
await set_cached(f"analysis:{text_hash}", json.dumps(result), ttl=600)
|
| 103 |
except Exception:
|
| 104 |
pass
|
| 105 |
|
| 106 |
-
# Persist to Firestore
|
| 107 |
try:
|
| 108 |
-
db = get_db()
|
| 109 |
doc = AnalysisResult(
|
| 110 |
input_text=text,
|
| 111 |
text_hash=text_hash,
|
|
@@ -122,8 +110,8 @@ async def _analyze_text(text: str, user_id: str = None) -> dict:
|
|
| 122 |
completed_at=datetime.now(timezone.utc),
|
| 123 |
processing_time_ms=processing_time_ms,
|
| 124 |
)
|
| 125 |
-
|
| 126 |
-
result["id"] = doc.id
|
| 127 |
except Exception as e:
|
| 128 |
logger.warning("Firestore persist failed", error=str(e))
|
| 129 |
result["id"] = text_hash
|
|
@@ -152,9 +140,7 @@ async def analyze_text(request: AnalyzeRequest):
|
|
| 152 |
s_styl=result["s_styl"],
|
| 153 |
p_watermark=result["p_watermark"],
|
| 154 |
),
|
| 155 |
-
explainability=[
|
| 156 |
-
ExplainabilityItem(**e) for e in result["explainability"]
|
| 157 |
-
],
|
| 158 |
processing_time_ms=result["processing_time_ms"],
|
| 159 |
)
|
| 160 |
|
|
@@ -176,21 +162,16 @@ async def bulk_analyze(request: BulkAnalyzeRequest):
|
|
| 176 |
return {"results": results}
|
| 177 |
|
| 178 |
|
| 179 |
-
@router.get("/results/{result_id}")
|
| 180 |
-
async def get_result(
|
| 181 |
-
result_id: str,
|
| 182 |
-
user_id: str = Depends(get_current_user),
|
| 183 |
-
):
|
| 184 |
"""Fetch a previously computed analysis result by Firestore document ID."""
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
doc = doc_ref.get()
|
| 188 |
-
if not doc.exists:
|
| 189 |
raise HTTPException(status_code=404, detail="Result not found")
|
| 190 |
-
|
| 191 |
return AnalyzeResponse(
|
| 192 |
id=data["id"],
|
| 193 |
-
status=data
|
| 194 |
threat_score=data.get("threat_score"),
|
| 195 |
signals=SignalScores(
|
| 196 |
p_ai=data.get("p_ai"),
|
|
@@ -200,8 +181,6 @@ async def get_result(
|
|
| 200 |
s_styl=data.get("s_styl"),
|
| 201 |
p_watermark=data.get("p_watermark"),
|
| 202 |
),
|
| 203 |
-
explainability=[
|
| 204 |
-
ExplainabilityItem(**e) for e in (data.get("explainability") or [])
|
| 205 |
-
],
|
| 206 |
processing_time_ms=data.get("processing_time_ms"),
|
| 207 |
)
|
|
|
|
| 1 |
"""
|
| 2 |
Main API routes for the LLM Misuse Detection system.
|
| 3 |
Endpoints: /api/analyze, /api/analyze/bulk, /api/results/{id}
|
| 4 |
+
Persistence: Firestore via REST helpers.
|
| 5 |
"""
|
| 6 |
import hashlib
|
| 7 |
+
import json
|
| 8 |
import time
|
| 9 |
from datetime import datetime, timezone
|
| 10 |
|
| 11 |
+
from fastapi import APIRouter, HTTPException
|
| 12 |
|
| 13 |
from backend.app.api.models import (
|
| 14 |
+
AnalyzeRequest,
|
| 15 |
+
AnalyzeResponse,
|
| 16 |
+
BulkAnalyzeRequest,
|
| 17 |
+
ExplainabilityItem,
|
| 18 |
+
SignalScores,
|
| 19 |
)
|
|
|
|
| 20 |
from backend.app.core.config import settings
|
| 21 |
+
from backend.app.core.logging import get_logger
|
| 22 |
from backend.app.core.redis import check_rate_limit, get_cached, set_cached
|
| 23 |
+
from backend.app.db.firestore import get_document, save_document
|
| 24 |
from backend.app.models.schemas import AnalysisResult
|
| 25 |
from backend.app.services.ensemble import compute_ensemble
|
|
|
|
| 26 |
from backend.app.services.groq_service import compute_perplexity
|
| 27 |
+
from backend.app.services.hf_service import detect_ai_text, detect_harm, get_embeddings
|
| 28 |
from backend.app.services.stylometry import compute_stylometry_score
|
| 29 |
from backend.app.services.vector_db import compute_cluster_score, upsert_embedding
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
logger = get_logger(__name__)
|
| 32 |
router = APIRouter(prefix="/api", tags=["analysis"])
|
|
|
|
| 34 |
COLLECTION = "analysis_results"
|
| 35 |
|
| 36 |
|
| 37 |
+
async def _analyze_text(text: str, user_id: str | None = None) -> dict:
|
| 38 |
"""Core analysis pipeline for a single text."""
|
| 39 |
start_time = time.time()
|
| 40 |
text_hash = hashlib.sha256(text.encode()).hexdigest()
|
| 41 |
|
|
|
|
| 42 |
cached = await get_cached(f"analysis:{text_hash}")
|
| 43 |
if cached:
|
| 44 |
return json.loads(cached)
|
| 45 |
|
|
|
|
| 46 |
try:
|
| 47 |
p_ai = await detect_ai_text(text)
|
| 48 |
except Exception:
|
| 49 |
p_ai = None
|
| 50 |
|
|
|
|
| 51 |
s_perp = None
|
| 52 |
if p_ai is not None and p_ai > settings.PERPLEXITY_THRESHOLD:
|
| 53 |
s_perp = await compute_perplexity(text)
|
| 54 |
|
|
|
|
| 55 |
s_embed_cluster = None
|
| 56 |
try:
|
| 57 |
embeddings = await get_embeddings(text)
|
|
|
|
| 60 |
except Exception:
|
| 61 |
pass
|
| 62 |
|
|
|
|
| 63 |
p_ext = await detect_harm(text)
|
|
|
|
|
|
|
| 64 |
s_styl = compute_stylometry_score(text)
|
|
|
|
|
|
|
| 65 |
p_watermark = None
|
| 66 |
|
|
|
|
| 67 |
ensemble_result = compute_ensemble(
|
| 68 |
p_ai=p_ai,
|
| 69 |
s_perp=s_perp,
|
|
|
|
| 88 |
"processing_time_ms": processing_time_ms,
|
| 89 |
}
|
| 90 |
|
|
|
|
| 91 |
try:
|
| 92 |
await set_cached(f"analysis:{text_hash}", json.dumps(result), ttl=600)
|
| 93 |
except Exception:
|
| 94 |
pass
|
| 95 |
|
|
|
|
| 96 |
try:
|
|
|
|
| 97 |
doc = AnalysisResult(
|
| 98 |
input_text=text,
|
| 99 |
text_hash=text_hash,
|
|
|
|
| 110 |
completed_at=datetime.now(timezone.utc),
|
| 111 |
processing_time_ms=processing_time_ms,
|
| 112 |
)
|
| 113 |
+
saved = await save_document(COLLECTION, doc.id, doc.to_dict())
|
| 114 |
+
result["id"] = doc.id if saved else text_hash
|
| 115 |
except Exception as e:
|
| 116 |
logger.warning("Firestore persist failed", error=str(e))
|
| 117 |
result["id"] = text_hash
|
|
|
|
| 140 |
s_styl=result["s_styl"],
|
| 141 |
p_watermark=result["p_watermark"],
|
| 142 |
),
|
| 143 |
+
explainability=[ExplainabilityItem(**e) for e in result["explainability"]],
|
|
|
|
|
|
|
| 144 |
processing_time_ms=result["processing_time_ms"],
|
| 145 |
)
|
| 146 |
|
|
|
|
| 162 |
return {"results": results}
|
| 163 |
|
| 164 |
|
| 165 |
+
@router.get("/results/{result_id}", response_model=AnalyzeResponse)
|
| 166 |
+
async def get_result(result_id: str):
|
|
|
|
|
|
|
|
|
|
| 167 |
"""Fetch a previously computed analysis result by Firestore document ID."""
|
| 168 |
+
data = await get_document(COLLECTION, result_id)
|
| 169 |
+
if not data:
|
|
|
|
|
|
|
| 170 |
raise HTTPException(status_code=404, detail="Result not found")
|
| 171 |
+
|
| 172 |
return AnalyzeResponse(
|
| 173 |
id=data["id"],
|
| 174 |
+
status=data.get("status", "done"),
|
| 175 |
threat_score=data.get("threat_score"),
|
| 176 |
signals=SignalScores(
|
| 177 |
p_ai=data.get("p_ai"),
|
|
|
|
| 181 |
s_styl=data.get("s_styl"),
|
| 182 |
p_watermark=data.get("p_watermark"),
|
| 183 |
),
|
| 184 |
+
explainability=[ExplainabilityItem(**e) for e in (data.get("explainability") or [])],
|
|
|
|
|
|
|
| 185 |
processing_time_ms=data.get("processing_time_ms"),
|
| 186 |
)
|
backend/app/core/config.py
CHANGED
|
@@ -26,10 +26,10 @@ class Settings(BaseSettings):
|
|
| 26 |
|
| 27 |
# HuggingFace
|
| 28 |
HF_API_KEY: str = ""
|
| 29 |
-
HF_DETECTOR_PRIMARY: str = f"{_HF_ROUTER}/
|
| 30 |
-
HF_DETECTOR_FALLBACK: str =
|
| 31 |
-
HF_EMBEDDINGS_PRIMARY: str =
|
| 32 |
-
HF_EMBEDDINGS_FALLBACK: str =
|
| 33 |
HF_HARM_CLASSIFIER: str = f"{_HF_ROUTER}/facebook/roberta-hate-speech-dynabench-r4-target"
|
| 34 |
|
| 35 |
# Groq
|
|
|
|
| 26 |
|
| 27 |
# HuggingFace
|
| 28 |
HF_API_KEY: str = ""
|
| 29 |
+
HF_DETECTOR_PRIMARY: str = f"{_HF_ROUTER}/Hello-SimpleAI/chatgpt-detector-roberta"
|
| 30 |
+
HF_DETECTOR_FALLBACK: str = ""
|
| 31 |
+
HF_EMBEDDINGS_PRIMARY: str = ""
|
| 32 |
+
HF_EMBEDDINGS_FALLBACK: str = ""
|
| 33 |
HF_HARM_CLASSIFIER: str = f"{_HF_ROUTER}/facebook/roberta-hate-speech-dynabench-r4-target"
|
| 34 |
|
| 35 |
# Groq
|
backend/app/db/firestore.py
CHANGED
|
@@ -1,68 +1,157 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
|
| 4 |
-
|
| 5 |
-
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
"""
|
|
|
|
|
|
|
| 12 |
import json
|
| 13 |
-
import
|
|
|
|
| 14 |
|
| 15 |
-
import
|
| 16 |
-
|
| 17 |
|
| 18 |
from backend.app.core.config import settings
|
| 19 |
from backend.app.core.logging import get_logger
|
| 20 |
|
| 21 |
logger = get_logger(__name__)
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
-
def _fix_private_key(
|
| 28 |
-
"""
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
Fix: replace literal \\n with real newline in private_key only.
|
| 33 |
-
"""
|
| 34 |
-
if "private_key" in cred_dict:
|
| 35 |
-
cred_dict["private_key"] = cred_dict["private_key"].replace("\\n", "\n")
|
| 36 |
-
return cred_dict
|
| 37 |
|
| 38 |
|
| 39 |
def init_firebase() -> None:
|
| 40 |
-
"""
|
| 41 |
-
global
|
| 42 |
-
if
|
|
|
|
| 43 |
return
|
| 44 |
-
|
| 45 |
-
if settings.FIREBASE_CREDENTIALS_JSON:
|
| 46 |
cred_dict = json.loads(settings.FIREBASE_CREDENTIALS_JSON)
|
| 47 |
cred_dict = _fix_private_key(cred_dict)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
def get_db():
|
| 66 |
-
if
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Firestore client using the Firestore REST API over plain HTTPS.
|
| 3 |
|
| 4 |
+
Why REST instead of firebase-admin + gRPC:
|
| 5 |
+
The firebase-admin SDK uses gRPC for Firestore. When FIREBASE_CREDENTIALS_JSON
|
| 6 |
+
is stored as an env-var in HF Spaces, the private_key newlines are double-escaped
|
| 7 |
+
(\\n instead of \n), causing 'invalid_grant: Invalid JWT Signature' errors that
|
| 8 |
+
fire in a tight background loop and spam the logs. The REST approach uses
|
| 9 |
+
google-auth (already installed) directly over HTTPS — no gRPC, no background
|
| 10 |
+
token-refresh loop, and the newline fix is applied once at startup.
|
| 11 |
|
| 12 |
+
Env vars required:
|
| 13 |
+
FIREBASE_CREDENTIALS_JSON – service account JSON string
|
| 14 |
+
FIREBASE_PROJECT_ID – e.g. "fir-config-d3c36"
|
| 15 |
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
import json
|
| 19 |
+
import httpx
|
| 20 |
+
from typing import Any
|
| 21 |
|
| 22 |
+
import google.oauth2.service_account as sa
|
| 23 |
+
import google.auth.transport.requests as ga_requests
|
| 24 |
|
| 25 |
from backend.app.core.config import settings
|
| 26 |
from backend.app.core.logging import get_logger
|
| 27 |
|
| 28 |
logger = get_logger(__name__)
|
| 29 |
|
| 30 |
+
_SCOPES = ["https://www.googleapis.com/auth/datastore"]
|
| 31 |
+
_FIRESTORE_BASE = "https://firestore.googleapis.com/v1"
|
| 32 |
+
|
| 33 |
+
_credentials: sa.Credentials | None = None
|
| 34 |
+
_project_id: str = ""
|
| 35 |
+
_enabled: bool = False
|
| 36 |
|
| 37 |
|
| 38 |
+
def _fix_private_key(d: dict) -> dict:
|
| 39 |
+
"""Unescape double-escaped newlines in private_key (common in env-var pastes)."""
|
| 40 |
+
if "private_key" in d:
|
| 41 |
+
d["private_key"] = d["private_key"].replace("\\n", "\n")
|
| 42 |
+
return d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def init_firebase() -> None:
|
| 46 |
+
"""Load service-account credentials. Non-fatal if misconfigured."""
|
| 47 |
+
global _credentials, _project_id, _enabled
|
| 48 |
+
if not settings.FIREBASE_CREDENTIALS_JSON:
|
| 49 |
+
logger.warning("FIREBASE_CREDENTIALS_JSON not set – Firestore disabled")
|
| 50 |
return
|
| 51 |
+
try:
|
|
|
|
| 52 |
cred_dict = json.loads(settings.FIREBASE_CREDENTIALS_JSON)
|
| 53 |
cred_dict = _fix_private_key(cred_dict)
|
| 54 |
+
_credentials = sa.Credentials.from_service_account_info(cred_dict, scopes=_SCOPES)
|
| 55 |
+
_project_id = settings.FIREBASE_PROJECT_ID or cred_dict.get("project_id", "")
|
| 56 |
+
# Validate credentials once at startup to avoid repeated runtime failures.
|
| 57 |
+
req = ga_requests.Request()
|
| 58 |
+
_credentials.refresh(req)
|
| 59 |
+
_enabled = True
|
| 60 |
+
logger.info("Firebase REST client initialised", project=_project_id)
|
| 61 |
+
except Exception as e:
|
| 62 |
+
_credentials = None
|
| 63 |
+
_enabled = False
|
| 64 |
+
logger.warning("Firebase init failed – Firestore disabled", error=str(e))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _auth_headers() -> dict:
|
| 68 |
+
"""Return a fresh Bearer token header (refreshes automatically when needed)."""
|
| 69 |
+
req = ga_requests.Request()
|
| 70 |
+
_credentials.refresh(req)
|
| 71 |
+
return {"Authorization": f"Bearer {_credentials.token}"}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _collection_url(collection: str) -> str:
|
| 75 |
+
return f"{_FIRESTORE_BASE}/projects/{_project_id}/databases/(default)/documents/{collection}"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _doc_url(collection: str, doc_id: str) -> str:
|
| 79 |
+
return f"{_collection_url(collection)}/{doc_id}"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _to_firestore_value(v: Any) -> dict:
|
| 83 |
+
"""Convert a Python value to a Firestore REST value object."""
|
| 84 |
+
if isinstance(v, bool):
|
| 85 |
+
return {"booleanValue": v}
|
| 86 |
+
if isinstance(v, int):
|
| 87 |
+
return {"integerValue": str(v)}
|
| 88 |
+
if isinstance(v, float):
|
| 89 |
+
return {"doubleValue": v}
|
| 90 |
+
if isinstance(v, str):
|
| 91 |
+
return {"stringValue": v}
|
| 92 |
+
if v is None:
|
| 93 |
+
return {"nullValue": None}
|
| 94 |
+
if isinstance(v, dict):
|
| 95 |
+
return {"mapValue": {"fields": {k: _to_firestore_value(u) for k, u in v.items()}}}
|
| 96 |
+
if isinstance(v, list):
|
| 97 |
+
return {"arrayValue": {"values": [_to_firestore_value(i) for i in v]}}
|
| 98 |
+
return {"stringValue": str(v)}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _from_firestore_value(v: dict) -> Any:
|
| 102 |
+
"""Convert a Firestore REST value object to a Python value."""
|
| 103 |
+
if "stringValue" in v: return v["stringValue"]
|
| 104 |
+
if "integerValue" in v: return int(v["integerValue"])
|
| 105 |
+
if "doubleValue" in v: return float(v["doubleValue"])
|
| 106 |
+
if "booleanValue" in v: return v["booleanValue"]
|
| 107 |
+
if "nullValue" in v: return None
|
| 108 |
+
if "mapValue" in v: return {k: _from_firestore_value(u) for k, u in v["mapValue"].get("fields", {}).items()}
|
| 109 |
+
if "arrayValue" in v: return [_from_firestore_value(i) for i in v["arrayValue"].get("values", [])]
|
| 110 |
+
return None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ---- Public helpers --------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
async def save_document(collection: str, doc_id: str, data: dict) -> bool:
|
| 116 |
+
"""Create or overwrite a Firestore document. Returns True on success."""
|
| 117 |
+
if not _enabled:
|
| 118 |
+
return False
|
| 119 |
+
try:
|
| 120 |
+
fields = {k: _to_firestore_value(v) for k, v in data.items()}
|
| 121 |
+
url = _doc_url(collection, doc_id)
|
| 122 |
+
async with httpx.AsyncClient(timeout=10.0) as client:
|
| 123 |
+
resp = await client.patch(
|
| 124 |
+
url,
|
| 125 |
+
json={"fields": fields},
|
| 126 |
+
headers=_auth_headers(),
|
| 127 |
+
)
|
| 128 |
+
resp.raise_for_status()
|
| 129 |
+
return True
|
| 130 |
+
except Exception as e:
|
| 131 |
+
logger.warning("Firestore save_document failed", collection=collection, doc_id=doc_id, error=str(e))
|
| 132 |
+
return False
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
async def get_document(collection: str, doc_id: str) -> dict | None:
|
| 136 |
+
"""Fetch a single Firestore document. Returns None if not found or disabled."""
|
| 137 |
+
if not _enabled:
|
| 138 |
+
return None
|
| 139 |
+
try:
|
| 140 |
+
url = _doc_url(collection, doc_id)
|
| 141 |
+
async with httpx.AsyncClient(timeout=10.0) as client:
|
| 142 |
+
resp = await client.get(url, headers=_auth_headers())
|
| 143 |
+
if resp.status_code == 404:
|
| 144 |
+
return None
|
| 145 |
+
resp.raise_for_status()
|
| 146 |
+
fields = resp.json().get("fields", {})
|
| 147 |
+
return {k: _from_firestore_value(v) for k, v in fields.items()}
|
| 148 |
+
except Exception as e:
|
| 149 |
+
logger.warning("Firestore get_document failed", collection=collection, doc_id=doc_id, error=str(e))
|
| 150 |
+
return None
|
| 151 |
|
| 152 |
|
| 153 |
def get_db():
|
| 154 |
+
"""Legacy shim for code that calls get_db(). Returns None if Firestore is disabled."""
|
| 155 |
+
if not _enabled:
|
| 156 |
+
return None
|
| 157 |
+
return True # callers should use save_document/get_document directly
|
backend/app/services/groq_service.py
CHANGED
|
@@ -1,46 +1,32 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
|
| 5 |
-
Env vars: GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
|
| 6 |
-
"""
|
| 7 |
import math
|
| 8 |
-
import httpx
|
| 9 |
-
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
| 10 |
from typing import Optional
|
| 11 |
|
|
|
|
|
|
|
| 12 |
from backend.app.core.config import settings
|
| 13 |
from backend.app.core.logging import get_logger
|
| 14 |
|
| 15 |
logger = get_logger(__name__)
|
| 16 |
|
| 17 |
-
_TIMEOUT = httpx.Timeout(
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
|
| 24 |
-
@retry(
|
| 25 |
-
stop=stop_after_attempt(3),
|
| 26 |
-
wait=wait_exponential(multiplier=1, min=2, max=15),
|
| 27 |
-
retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.ConnectError)),
|
| 28 |
-
)
|
| 29 |
-
async def _groq_chat_completion(text: str) -> dict:
|
| 30 |
-
"""Call Groq chat completion with logprobs enabled.
|
| 31 |
-
Note: Input is truncated to 2000 chars for cost control. Perplexity
|
| 32 |
-
scores for longer texts reflect only the first 2000 characters.
|
| 33 |
-
"""
|
| 34 |
headers = {
|
| 35 |
"Authorization": f"Bearer {settings.GROQ_API_KEY}",
|
| 36 |
"Content-Type": "application/json",
|
| 37 |
}
|
| 38 |
payload = {
|
| 39 |
-
"model":
|
| 40 |
-
"messages": [
|
| 41 |
-
{"role": "system", "content": "Repeat the following text exactly:"},
|
| 42 |
-
{"role": "user", "content": text[:2000]}, # Truncated for cost control
|
| 43 |
-
],
|
| 44 |
"max_tokens": 1,
|
| 45 |
"temperature": 0,
|
| 46 |
"logprobs": True,
|
|
@@ -52,54 +38,44 @@ async def _groq_chat_completion(text: str) -> dict:
|
|
| 52 |
json=payload,
|
| 53 |
headers=headers,
|
| 54 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
resp.raise_for_status()
|
| 56 |
return resp.json()
|
| 57 |
|
| 58 |
|
| 59 |
async def compute_perplexity(text: str) -> Optional[float]:
|
| 60 |
-
"""
|
| 61 |
-
Compute a normalized perplexity score using Groq Llama endpoints.
|
| 62 |
-
Returns a score between 0 and 1 where higher = more anomalous.
|
| 63 |
-
|
| 64 |
-
Strategy: Use logprobs from a single completion call to estimate
|
| 65 |
-
the model's surprise at the input text.
|
| 66 |
-
"""
|
| 67 |
try:
|
| 68 |
result = await _groq_chat_completion(text)
|
|
|
|
|
|
|
|
|
|
| 69 |
choices = result.get("choices", [])
|
| 70 |
if not choices:
|
| 71 |
return None
|
| 72 |
|
| 73 |
-
logprobs_data = choices[0].get("logprobs"
|
| 74 |
-
|
| 75 |
-
|
|
|
|
| 76 |
usage = result.get("usage", {})
|
| 77 |
prompt_tokens = usage.get("prompt_tokens", 0)
|
| 78 |
if prompt_tokens > 0:
|
| 79 |
-
text_len = len(text.split())
|
| 80 |
-
ratio = prompt_tokens /
|
| 81 |
-
|
| 82 |
-
return min(1.0, max(0.0, (ratio - 1.0) / 2.0))
|
| 83 |
return None
|
| 84 |
|
| 85 |
-
|
| 86 |
-
if not content:
|
| 87 |
-
return None
|
| 88 |
-
|
| 89 |
-
# Compute perplexity from log-probabilities
|
| 90 |
-
log_probs = []
|
| 91 |
-
for token_info in content:
|
| 92 |
-
lp = token_info.get("logprob")
|
| 93 |
-
if lp is not None:
|
| 94 |
-
log_probs.append(lp)
|
| 95 |
-
|
| 96 |
if not log_probs:
|
| 97 |
return None
|
| 98 |
|
| 99 |
avg_log_prob = sum(log_probs) / len(log_probs)
|
| 100 |
perplexity = math.exp(-avg_log_prob)
|
| 101 |
-
|
| 102 |
-
normalized = min(1.0, max(0.0, (math.log(perplexity + 1) / math.log(101))))
|
| 103 |
return round(normalized, 4)
|
| 104 |
except Exception as e:
|
| 105 |
logger.warning("Groq perplexity computation failed", error=str(e))
|
|
|
|
| 1 |
+
"""Groq API client for optional perplexity scoring."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
|
|
|
|
|
|
|
| 5 |
import math
|
|
|
|
|
|
|
| 6 |
from typing import Optional
|
| 7 |
|
| 8 |
+
import httpx
|
| 9 |
+
|
| 10 |
from backend.app.core.config import settings
|
| 11 |
from backend.app.core.logging import get_logger
|
| 12 |
|
| 13 |
logger = get_logger(__name__)
|
| 14 |
|
| 15 |
+
_TIMEOUT = httpx.Timeout(30.0, connect=10.0)
|
| 16 |
+
_LOGPROBS_MODEL = "llama-3.1-8b-instant"
|
| 17 |
|
| 18 |
|
| 19 |
+
async def _groq_chat_completion(text: str) -> Optional[dict]:
|
| 20 |
+
if not settings.GROQ_API_KEY:
|
| 21 |
+
return None
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
headers = {
|
| 24 |
"Authorization": f"Bearer {settings.GROQ_API_KEY}",
|
| 25 |
"Content-Type": "application/json",
|
| 26 |
}
|
| 27 |
payload = {
|
| 28 |
+
"model": _LOGPROBS_MODEL,
|
| 29 |
+
"messages": [{"role": "user", "content": text[:1500]}],
|
|
|
|
|
|
|
|
|
|
| 30 |
"max_tokens": 1,
|
| 31 |
"temperature": 0,
|
| 32 |
"logprobs": True,
|
|
|
|
| 38 |
json=payload,
|
| 39 |
headers=headers,
|
| 40 |
)
|
| 41 |
+
# 4xx means unsupported model/params for this key-tier; do not spam retries.
|
| 42 |
+
if 400 <= resp.status_code < 500:
|
| 43 |
+
logger.info("Groq perplexity unavailable for current deployment", status_code=resp.status_code)
|
| 44 |
+
return None
|
| 45 |
resp.raise_for_status()
|
| 46 |
return resp.json()
|
| 47 |
|
| 48 |
|
| 49 |
async def compute_perplexity(text: str) -> Optional[float]:
|
| 50 |
+
"""Compute a normalized perplexity score (0-1). Returns None on failure."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
try:
|
| 52 |
result = await _groq_chat_completion(text)
|
| 53 |
+
if not result:
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
choices = result.get("choices", [])
|
| 57 |
if not choices:
|
| 58 |
return None
|
| 59 |
|
| 60 |
+
logprobs_data = choices[0].get("logprobs") or {}
|
| 61 |
+
content = logprobs_data.get("content") or []
|
| 62 |
+
|
| 63 |
+
if not content:
|
| 64 |
usage = result.get("usage", {})
|
| 65 |
prompt_tokens = usage.get("prompt_tokens", 0)
|
| 66 |
if prompt_tokens > 0:
|
| 67 |
+
text_len = max(len(text.split()), 1)
|
| 68 |
+
ratio = prompt_tokens / text_len
|
| 69 |
+
return round(min(1.0, max(0.0, (ratio - 1.0) / 2.0)), 4)
|
|
|
|
| 70 |
return None
|
| 71 |
|
| 72 |
+
log_probs = [t["logprob"] for t in content if t.get("logprob") is not None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
if not log_probs:
|
| 74 |
return None
|
| 75 |
|
| 76 |
avg_log_prob = sum(log_probs) / len(log_probs)
|
| 77 |
perplexity = math.exp(-avg_log_prob)
|
| 78 |
+
normalized = min(1.0, max(0.0, math.log(perplexity + 1) / math.log(101)))
|
|
|
|
| 79 |
return round(normalized, 4)
|
| 80 |
except Exception as e:
|
| 81 |
logger.warning("Groq perplexity computation failed", error=str(e))
|
backend/app/services/hf_service.py
CHANGED
|
@@ -1,26 +1,24 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
import httpx
|
| 10 |
-
from tenacity import retry,
|
| 11 |
-
from typing import List, Optional, Dict, Any
|
| 12 |
|
| 13 |
from backend.app.core.config import settings
|
| 14 |
from backend.app.core.logging import get_logger
|
| 15 |
|
| 16 |
logger = get_logger(__name__)
|
| 17 |
|
| 18 |
-
_HEADERS = lambda: {"Authorization": f"Bearer {settings.HF_API_KEY}"}
|
| 19 |
_TIMEOUT = httpx.Timeout(30.0, connect=10.0)
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
|
| 25 |
|
| 26 |
@retry(
|
|
@@ -28,74 +26,97 @@ class HFServiceError(Exception):
|
|
| 28 |
wait=wait_exponential(multiplier=1, min=1, max=10),
|
| 29 |
retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.ConnectError)),
|
| 30 |
)
|
| 31 |
-
async def
|
| 32 |
async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
|
| 33 |
-
resp = await client.post(url, json=payload, headers=
|
| 34 |
resp.raise_for_status()
|
| 35 |
return resp.json()
|
| 36 |
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
async def detect_ai_text(text: str) -> float:
|
| 39 |
-
"""
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
"""
|
| 43 |
-
scores = []
|
| 44 |
-
for url in [settings.HF_DETECTOR_PRIMARY, settings.HF_DETECTOR_FALLBACK]:
|
| 45 |
try:
|
| 46 |
-
result = await
|
| 47 |
-
# HF classification returns [[{label, score}, ...]]
|
| 48 |
if isinstance(result, list) and len(result) > 0:
|
| 49 |
labels = result[0] if isinstance(result[0], list) else result
|
| 50 |
for item in labels:
|
| 51 |
label = item.get("label", "").lower()
|
| 52 |
-
if
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
break
|
| 55 |
else:
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
scores.append(labels[0].get("score", 0.5))
|
| 59 |
except Exception as e:
|
| 60 |
logger.warning("HF detector call failed", url=url, error=str(e))
|
|
|
|
| 61 |
if not scores:
|
| 62 |
-
raise
|
| 63 |
-
return sum(scores) / len(scores)
|
| 64 |
|
| 65 |
|
| 66 |
-
async def get_embeddings(text: str) ->
|
| 67 |
-
"""
|
| 68 |
-
for url in
|
| 69 |
try:
|
| 70 |
-
result = await
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
if isinstance(result[0], list):
|
| 76 |
-
return result[0]
|
| 77 |
-
return result
|
| 78 |
except Exception as e:
|
| 79 |
logger.warning("HF embeddings call failed", url=url, error=str(e))
|
| 80 |
-
|
|
|
|
|
|
|
| 81 |
|
| 82 |
|
| 83 |
async def detect_harm(text: str) -> float:
|
| 84 |
-
"""
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
try:
|
| 89 |
-
result = await
|
| 90 |
if isinstance(result, list) and len(result) > 0:
|
| 91 |
labels = result[0] if isinstance(result[0], list) else result
|
| 92 |
for item in labels:
|
| 93 |
label = item.get("label", "").lower()
|
| 94 |
-
if label in ("hate", "toxic", "harmful", "
|
| 95 |
-
return item["score"]
|
| 96 |
-
|
| 97 |
-
if labels:
|
| 98 |
-
return max(item.get("score", 0.0) for item in labels)
|
| 99 |
return 0.0
|
| 100 |
except Exception as e:
|
| 101 |
logger.warning("HF harm classifier failed", error=str(e))
|
|
|
|
| 1 |
+
"""Hugging Face Inference API helpers with resilient fallbacks."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
|
|
|
| 8 |
import httpx
|
| 9 |
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
|
|
|
| 10 |
|
| 11 |
from backend.app.core.config import settings
|
| 12 |
from backend.app.core.logging import get_logger
|
| 13 |
|
| 14 |
logger = get_logger(__name__)
|
| 15 |
|
|
|
|
| 16 |
_TIMEOUT = httpx.Timeout(30.0, connect=10.0)
|
| 17 |
+
_LOCAL_EMBEDDING_DIM = 384
|
| 18 |
|
| 19 |
|
| 20 |
+
def _headers() -> dict:
|
| 21 |
+
return {"Authorization": f"Bearer {settings.HF_API_KEY}"}
|
| 22 |
|
| 23 |
|
| 24 |
@retry(
|
|
|
|
| 26 |
wait=wait_exponential(multiplier=1, min=1, max=10),
|
| 27 |
retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.ConnectError)),
|
| 28 |
)
|
| 29 |
+
async def _hf_post(url: str, payload: dict) -> Any:
|
| 30 |
async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
|
| 31 |
+
resp = await client.post(url, json=payload, headers=_headers())
|
| 32 |
resp.raise_for_status()
|
| 33 |
return resp.json()
|
| 34 |
|
| 35 |
|
| 36 |
+
def _configured_urls(*urls: str) -> list[str]:
|
| 37 |
+
return [url for url in urls if url and url.strip()]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _local_embedding(text: str, dim: int = _LOCAL_EMBEDDING_DIM) -> list[float]:
|
| 41 |
+
"""Deterministic no-network embedding fallback to keep pipeline stable."""
|
| 42 |
+
seed = hashlib.sha256(text.encode("utf-8")).digest()
|
| 43 |
+
values: list[float] = []
|
| 44 |
+
block = seed
|
| 45 |
+
while len(values) < dim:
|
| 46 |
+
block = hashlib.sha256(block + text.encode("utf-8")).digest()
|
| 47 |
+
for byte in block:
|
| 48 |
+
values.append((byte / 127.5) - 1.0)
|
| 49 |
+
if len(values) == dim:
|
| 50 |
+
break
|
| 51 |
+
return values
|
| 52 |
+
|
| 53 |
+
|
| 54 |
async def detect_ai_text(text: str) -> float:
    """Return the probability (0-1) that *text* is AI-generated.

    Queries every configured detector endpoint and averages their scores.
    Per-endpoint failures are logged and skipped.

    Raises:
        RuntimeError: if no configured detector produced a usable score.
    """
    # Label substrings that mark the "AI / machine-generated" class across
    # detector models. Hoisted out of the loop; "gpt" also covers "chatgpt".
    ai_label_keys = ("ai", "fake", "machine", "generated", "gpt", "class_1", "label_1")
    scores: list[float] = []
    for url in _configured_urls(settings.HF_DETECTOR_PRIMARY, settings.HF_DETECTOR_FALLBACK):
        try:
            result = await _hf_post(url, {"inputs": text})
            if isinstance(result, list) and len(result) > 0:
                labels = result[0] if isinstance(result[0], list) else result
                for item in labels:
                    label = item.get("label", "").lower()
                    if any(k in label for k in ai_label_keys):
                        scores.append(float(item["score"]))
                        break
                else:
                    # No recognizable AI label: fall back to the top-scoring class.
                    best = max(labels, key=lambda x: x.get("score", 0))
                    scores.append(float(best.get("score", 0.5)))
        except Exception as e:
            # Non-fatal: try the next configured detector.
            logger.warning("HF detector call failed", url=url, error=str(e))

    if not scores:
        # RuntimeError is an Exception subclass, so existing `except Exception`
        # callers keep working while the type is narrower and more descriptive.
        raise RuntimeError("All AI detectors failed")
    return round(sum(scores) / len(scores), 4)
|
| 88 |
|
| 89 |
|
| 90 |
+
async def get_embeddings(text: str) -> list[float]:
    """Return an embedding vector for *text*, preferring remote HF endpoints.

    Each configured endpoint is tried in order; if all fail, a deterministic
    local embedding is used so the downstream pipeline never stalls.
    """
    endpoints = _configured_urls(settings.HF_EMBEDDINGS_PRIMARY, settings.HF_EMBEDDINGS_FALLBACK)
    for endpoint in endpoints:
        try:
            payload = await _hf_post(endpoint, {"inputs": text})
            # Responses may arrive wrapped in extra list nesting; peel it off.
            while isinstance(payload, list) and payload and isinstance(payload[0], list):
                payload = payload[0]
            if isinstance(payload, list) and payload and isinstance(payload[0], (float, int)):
                return [float(component) for component in payload]
        except Exception as exc:
            logger.warning("HF embeddings call failed", url=endpoint, error=str(exc))

    logger.info("Using local deterministic embeddings fallback")
    return _local_embedding(text)
|
| 104 |
|
| 105 |
|
| 106 |
async def detect_harm(text: str) -> float:
|
| 107 |
+
"""Returns probability of harmful content (0-1). Non-fatal on failure."""
|
| 108 |
+
if not settings.HF_HARM_CLASSIFIER:
|
| 109 |
+
return 0.0
|
| 110 |
+
|
| 111 |
try:
|
| 112 |
+
result = await _hf_post(settings.HF_HARM_CLASSIFIER, {"inputs": text})
|
| 113 |
if isinstance(result, list) and len(result) > 0:
|
| 114 |
labels = result[0] if isinstance(result[0], list) else result
|
| 115 |
for item in labels:
|
| 116 |
label = item.get("label", "").lower()
|
| 117 |
+
if any(k in label for k in ("hate", "toxic", "harmful", "hateful", "target")):
|
| 118 |
+
return float(item["score"])
|
| 119 |
+
return float(max(labels, key=lambda x: x.get("score", 0)).get("score", 0.0))
|
|
|
|
|
|
|
| 120 |
return 0.0
|
| 121 |
except Exception as e:
|
| 122 |
logger.warning("HF harm classifier failed", error=str(e))
|
backend/tests/test_api.py
CHANGED
|
@@ -9,29 +9,16 @@ from fastapi.testclient import TestClient
|
|
| 9 |
|
| 10 |
def _make_client():
|
| 11 |
"""Build a TestClient with Firebase init and Firestore writes mocked out."""
|
| 12 |
-
# Patch
|
| 13 |
-
with patch("backend.app.db.firestore.
|
| 14 |
-
patch("backend.app.db.firestore.
|
|
|
|
|
|
|
| 15 |
from backend.app.main import app
|
| 16 |
|
| 17 |
-
# Mock get_db() so Firestore document writes are no-ops
|
| 18 |
-
mock_db = MagicMock()
|
| 19 |
-
mock_collection = MagicMock()
|
| 20 |
-
mock_doc_ref = MagicMock()
|
| 21 |
-
mock_db.collection.return_value = mock_collection
|
| 22 |
-
mock_collection.document.return_value = mock_doc_ref
|
| 23 |
-
mock_doc_ref.set.return_value = None
|
| 24 |
-
|
| 25 |
-
# Mock get_result Firestore read
|
| 26 |
-
mock_existing_doc = MagicMock()
|
| 27 |
-
mock_existing_doc.exists = False
|
| 28 |
-
mock_doc_ref.get.return_value = mock_existing_doc
|
| 29 |
-
|
| 30 |
app.dependency_overrides = {}
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
client = TestClient(app)
|
| 34 |
-
return client, mock_db
|
| 35 |
|
| 36 |
|
| 37 |
@pytest.fixture
|
|
@@ -65,15 +52,11 @@ class TestAnalyzeEndpoint:
|
|
| 65 |
@patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.2)
|
| 66 |
@patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
|
| 67 |
@patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
|
| 68 |
-
@patch("backend.app.
|
| 69 |
def test_analyze_returns_scores(
|
| 70 |
-
self,
|
| 71 |
mock_cluster, mock_embed, mock_perp, mock_ai, mock_rate, client
|
| 72 |
):
|
| 73 |
-
mock_db = MagicMock()
|
| 74 |
-
mock_db.collection.return_value.document.return_value.set.return_value = None
|
| 75 |
-
mock_get_db.return_value = mock_db
|
| 76 |
-
|
| 77 |
response = client.post(
|
| 78 |
"/api/analyze",
|
| 79 |
json={"text": "This is a test text that should be analyzed for potential misuse patterns."},
|
|
@@ -110,15 +93,11 @@ class TestAttackSimulations:
|
|
| 110 |
@patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.9)
|
| 111 |
@patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
|
| 112 |
@patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
|
| 113 |
-
@patch("backend.app.
|
| 114 |
def test_high_threat_detection(
|
| 115 |
-
self,
|
| 116 |
mock_cluster, mock_embed, mock_perp, mock_ai, mock_rate, client
|
| 117 |
):
|
| 118 |
-
mock_db = MagicMock()
|
| 119 |
-
mock_db.collection.return_value.document.return_value.set.return_value = None
|
| 120 |
-
mock_get_db.return_value = mock_db
|
| 121 |
-
|
| 122 |
response = client.post(
|
| 123 |
"/api/analyze",
|
| 124 |
json={"text": "Simulated high-threat content for testing purposes only. This is a test."},
|
|
@@ -135,15 +114,11 @@ class TestAttackSimulations:
|
|
| 135 |
@patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.02)
|
| 136 |
@patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
|
| 137 |
@patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
|
| 138 |
-
@patch("backend.app.
|
| 139 |
def test_benign_text_low_threat(
|
| 140 |
-
self,
|
| 141 |
mock_cluster, mock_embed, mock_ai, mock_rate, client
|
| 142 |
):
|
| 143 |
-
mock_db = MagicMock()
|
| 144 |
-
mock_db.collection.return_value.document.return_value.set.return_value = None
|
| 145 |
-
mock_get_db.return_value = mock_db
|
| 146 |
-
|
| 147 |
response = client.post(
|
| 148 |
"/api/analyze",
|
| 149 |
json={"text": "The weather today is sunny with clear skies and mild temperatures across the region."},
|
|
|
|
| 9 |
|
| 10 |
def _make_client():
|
| 11 |
"""Build a TestClient with Firebase init and Firestore writes mocked out."""
|
| 12 |
+
# Patch init_firebase and the _enabled flag so the app starts without real credentials
|
| 13 |
+
with patch("backend.app.db.firestore.init_firebase"), \
|
| 14 |
+
patch("backend.app.db.firestore._enabled", True), \
|
| 15 |
+
patch("backend.app.db.firestore.save_document", new_callable=AsyncMock, return_value=True), \
|
| 16 |
+
patch("backend.app.db.firestore.get_document", new_callable=AsyncMock, return_value=None):
|
| 17 |
from backend.app.main import app
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
app.dependency_overrides = {}
|
| 20 |
+
client = TestClient(app)
|
| 21 |
+
return client, None
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
@pytest.fixture
|
|
|
|
| 52 |
@patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.2)
|
| 53 |
@patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
|
| 54 |
@patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
|
| 55 |
+
@patch("backend.app.db.firestore.save_document", new_callable=AsyncMock, return_value=True)
|
| 56 |
def test_analyze_returns_scores(
|
| 57 |
+
self, mock_save, mock_set_cache, mock_get_cache, mock_harm, mock_upsert,
|
| 58 |
mock_cluster, mock_embed, mock_perp, mock_ai, mock_rate, client
|
| 59 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
response = client.post(
|
| 61 |
"/api/analyze",
|
| 62 |
json={"text": "This is a test text that should be analyzed for potential misuse patterns."},
|
|
|
|
| 93 |
@patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.9)
|
| 94 |
@patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
|
| 95 |
@patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
|
| 96 |
+
@patch("backend.app.db.firestore.save_document", new_callable=AsyncMock, return_value=True)
|
| 97 |
def test_high_threat_detection(
|
| 98 |
+
self, mock_save, mock_set_cache, mock_get_cache, mock_harm, mock_upsert,
|
| 99 |
mock_cluster, mock_embed, mock_perp, mock_ai, mock_rate, client
|
| 100 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
response = client.post(
|
| 102 |
"/api/analyze",
|
| 103 |
json={"text": "Simulated high-threat content for testing purposes only. This is a test."},
|
|
|
|
| 114 |
@patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.02)
|
| 115 |
@patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
|
| 116 |
@patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
|
| 117 |
+
@patch("backend.app.db.firestore.save_document", new_callable=AsyncMock, return_value=True)
|
| 118 |
def test_benign_text_low_threat(
|
| 119 |
+
self, mock_save, mock_set_cache, mock_get_cache, mock_harm, mock_upsert,
|
| 120 |
mock_cluster, mock_embed, mock_ai, mock_rate, client
|
| 121 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
response = client.post(
|
| 123 |
"/api/analyze",
|
| 124 |
json={"text": "The weather today is sunny with clear skies and mild temperatures across the region."},
|
backend/tests/test_services.py
CHANGED
|
@@ -12,9 +12,9 @@ from backend.app.services.groq_service import compute_perplexity
|
|
| 12 |
|
| 13 |
class TestHFService:
|
| 14 |
@pytest.mark.asyncio
|
| 15 |
-
@patch("backend.app.services.hf_service.
|
| 16 |
-
async def test_detect_ai_text_success(self,
|
| 17 |
-
|
| 18 |
{"label": "AI", "score": 0.92},
|
| 19 |
{"label": "Human", "score": 0.08},
|
| 20 |
]]
|
|
@@ -22,27 +22,39 @@ class TestHFService:
|
|
| 22 |
assert 0.0 <= score <= 1.0
|
| 23 |
|
| 24 |
@pytest.mark.asyncio
|
| 25 |
-
@patch("backend.app.services.hf_service.
|
| 26 |
-
async def test_detect_ai_text_fallback(self,
|
| 27 |
-
"""If primary fails
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
]
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
| 33 |
assert 0.0 <= score <= 1.0
|
| 34 |
|
| 35 |
@pytest.mark.asyncio
|
| 36 |
-
@patch("backend.app.services.hf_service.
|
| 37 |
-
async def test_get_embeddings_success(self,
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
assert len(result) == 768
|
| 41 |
|
| 42 |
@pytest.mark.asyncio
|
| 43 |
-
@patch("backend.app.services.hf_service.
|
| 44 |
-
async def test_detect_harm_success(self,
|
| 45 |
-
|
| 46 |
{"label": "hate", "score": 0.15},
|
| 47 |
{"label": "not_hate", "score": 0.85},
|
| 48 |
]]
|
|
@@ -77,7 +89,6 @@ class TestGroqService:
|
|
| 77 |
"usage": {"prompt_tokens": 15},
|
| 78 |
}
|
| 79 |
score = await compute_perplexity("Test text without logprobs available")
|
| 80 |
-
# May return None or a heuristic value
|
| 81 |
assert score is None or 0.0 <= score <= 1.0
|
| 82 |
|
| 83 |
@pytest.mark.asyncio
|
|
|
|
| 12 |
|
| 13 |
class TestHFService:
|
| 14 |
@pytest.mark.asyncio
|
| 15 |
+
@patch("backend.app.services.hf_service._hf_post", new_callable=AsyncMock)
|
| 16 |
+
async def test_detect_ai_text_success(self, mock_post):
|
| 17 |
+
mock_post.return_value = [[
|
| 18 |
{"label": "AI", "score": 0.92},
|
| 19 |
{"label": "Human", "score": 0.08},
|
| 20 |
]]
|
|
|
|
| 22 |
assert 0.0 <= score <= 1.0
|
| 23 |
|
| 24 |
@pytest.mark.asyncio
|
| 25 |
+
@patch("backend.app.services.hf_service._hf_post", new_callable=AsyncMock)
|
| 26 |
+
async def test_detect_ai_text_fallback(self, mock_post):
|
| 27 |
+
"""If primary fails immediately (non-retried error), fallback URL should succeed."""
|
| 28 |
+
# Use ConnectError so tenacity does NOT retry (only HTTPStatusError/ConnectError
|
| 29 |
+
# with stop_after_attempt=3 would retry, but we want ONE failure then fallback).
|
| 30 |
+
# Actually tenacity retries on ConnectError too, so we use a plain Exception
|
| 31 |
+
# which is NOT in the retry predicate — it propagates immediately, letting
|
| 32 |
+
# detect_ai_text catch it in its try/except and move to the fallback URL.
|
| 33 |
+
mock_post.side_effect = [
|
| 34 |
+
Exception("Primary failed"), # primary URL -> caught, move on
|
| 35 |
+
[[{"label": "FAKE", "score": 0.75}]], # fallback URL -> success
|
| 36 |
]
|
| 37 |
+
with patch("backend.app.services.hf_service.settings") as mock_settings:
|
| 38 |
+
mock_settings.HF_DETECTOR_PRIMARY = "https://primary.example.com"
|
| 39 |
+
mock_settings.HF_DETECTOR_FALLBACK = "https://fallback.example.com"
|
| 40 |
+
score = await detect_ai_text("Test text")
|
| 41 |
assert 0.0 <= score <= 1.0
|
| 42 |
|
| 43 |
@pytest.mark.asyncio
|
| 44 |
+
@patch("backend.app.services.hf_service._hf_post", new_callable=AsyncMock)
|
| 45 |
+
async def test_get_embeddings_success(self, mock_post):
|
| 46 |
+
"""Mock returns a 768-dim vector; assert we get back exactly that vector."""
|
| 47 |
+
mock_post.return_value = [0.1] * 768
|
| 48 |
+
with patch("backend.app.services.hf_service.settings") as mock_settings:
|
| 49 |
+
mock_settings.HF_EMBEDDINGS_PRIMARY = "https://embeddings.example.com"
|
| 50 |
+
mock_settings.HF_EMBEDDINGS_FALLBACK = ""
|
| 51 |
+
result = await get_embeddings("Test text")
|
| 52 |
assert len(result) == 768
|
| 53 |
|
| 54 |
@pytest.mark.asyncio
|
| 55 |
+
@patch("backend.app.services.hf_service._hf_post", new_callable=AsyncMock)
|
| 56 |
+
async def test_detect_harm_success(self, mock_post):
|
| 57 |
+
mock_post.return_value = [[
|
| 58 |
{"label": "hate", "score": 0.15},
|
| 59 |
{"label": "not_hate", "score": 0.85},
|
| 60 |
]]
|
|
|
|
| 89 |
"usage": {"prompt_tokens": 15},
|
| 90 |
}
|
| 91 |
score = await compute_perplexity("Test text without logprobs available")
|
|
|
|
| 92 |
assert score is None or 0.0 <= score <= 1.0
|
| 93 |
|
| 94 |
@pytest.mark.asyncio
|