GitHub Actions committed on
Commit
5d50b8b
·
1 Parent(s): 8396c67

Deploy backend from GitHub 522e1ff559eaf4f3a628b450c12e01b910565458

Browse files
backend/app/api/routes.py CHANGED
@@ -1,31 +1,32 @@
1
  """
2
  Main API routes for the LLM Misuse Detection system.
3
  Endpoints: /api/analyze, /api/analyze/bulk, /api/results/{id}
4
- Persistence: Firestore (replaces PostgreSQL)
5
  """
6
  import hashlib
 
7
  import time
8
  from datetime import datetime, timezone
9
 
10
- from fastapi import APIRouter, Depends, HTTPException
11
 
12
  from backend.app.api.models import (
13
- AnalyzeRequest, AnalyzeResponse, BulkAnalyzeRequest,
14
- SignalScores, ExplainabilityItem,
 
 
 
15
  )
16
- from backend.app.core.auth import get_current_user
17
  from backend.app.core.config import settings
 
18
  from backend.app.core.redis import check_rate_limit, get_cached, set_cached
19
- from backend.app.db.firestore import get_db
20
  from backend.app.models.schemas import AnalysisResult
21
  from backend.app.services.ensemble import compute_ensemble
22
- from backend.app.services.hf_service import detect_ai_text, get_embeddings, detect_harm
23
  from backend.app.services.groq_service import compute_perplexity
 
24
  from backend.app.services.stylometry import compute_stylometry_score
25
  from backend.app.services.vector_db import compute_cluster_score, upsert_embedding
26
- from backend.app.core.logging import get_logger
27
-
28
- import json
29
 
30
  logger = get_logger(__name__)
31
  router = APIRouter(prefix="/api", tags=["analysis"])
@@ -33,28 +34,24 @@ router = APIRouter(prefix="/api", tags=["analysis"])
33
  COLLECTION = "analysis_results"
34
 
35
 
36
- async def _analyze_text(text: str, user_id: str = None) -> dict:
37
  """Core analysis pipeline for a single text."""
38
  start_time = time.time()
39
  text_hash = hashlib.sha256(text.encode()).hexdigest()
40
 
41
- # Check cache
42
  cached = await get_cached(f"analysis:{text_hash}")
43
  if cached:
44
  return json.loads(cached)
45
 
46
- # Step 1: AI detection
47
  try:
48
  p_ai = await detect_ai_text(text)
49
  except Exception:
50
  p_ai = None
51
 
52
- # Step 2: Perplexity (cost-gated)
53
  s_perp = None
54
  if p_ai is not None and p_ai > settings.PERPLEXITY_THRESHOLD:
55
  s_perp = await compute_perplexity(text)
56
 
57
- # Step 3: Embeddings + cluster score
58
  s_embed_cluster = None
59
  try:
60
  embeddings = await get_embeddings(text)
@@ -63,16 +60,10 @@ async def _analyze_text(text: str, user_id: str = None) -> dict:
63
  except Exception:
64
  pass
65
 
66
- # Step 4: Harm/extremism
67
  p_ext = await detect_harm(text)
68
-
69
- # Step 5: Stylometry
70
  s_styl = compute_stylometry_score(text)
71
-
72
- # Step 6: Watermark placeholder
73
  p_watermark = None
74
 
75
- # Step 7: Ensemble
76
  ensemble_result = compute_ensemble(
77
  p_ai=p_ai,
78
  s_perp=s_perp,
@@ -97,15 +88,12 @@ async def _analyze_text(text: str, user_id: str = None) -> dict:
97
  "processing_time_ms": processing_time_ms,
98
  }
99
 
100
- # Cache
101
  try:
102
  await set_cached(f"analysis:{text_hash}", json.dumps(result), ttl=600)
103
  except Exception:
104
  pass
105
 
106
- # Persist to Firestore
107
  try:
108
- db = get_db()
109
  doc = AnalysisResult(
110
  input_text=text,
111
  text_hash=text_hash,
@@ -122,8 +110,8 @@ async def _analyze_text(text: str, user_id: str = None) -> dict:
122
  completed_at=datetime.now(timezone.utc),
123
  processing_time_ms=processing_time_ms,
124
  )
125
- db.collection(COLLECTION).document(doc.id).set(doc.to_dict())
126
- result["id"] = doc.id
127
  except Exception as e:
128
  logger.warning("Firestore persist failed", error=str(e))
129
  result["id"] = text_hash
@@ -152,9 +140,7 @@ async def analyze_text(request: AnalyzeRequest):
152
  s_styl=result["s_styl"],
153
  p_watermark=result["p_watermark"],
154
  ),
155
- explainability=[
156
- ExplainabilityItem(**e) for e in result["explainability"]
157
- ],
158
  processing_time_ms=result["processing_time_ms"],
159
  )
160
 
@@ -176,21 +162,16 @@ async def bulk_analyze(request: BulkAnalyzeRequest):
176
  return {"results": results}
177
 
178
 
179
- @router.get("/results/{result_id}")
180
- async def get_result(
181
- result_id: str,
182
- user_id: str = Depends(get_current_user),
183
- ):
184
  """Fetch a previously computed analysis result by Firestore document ID."""
185
- db = get_db()
186
- doc_ref = db.collection(COLLECTION).document(result_id)
187
- doc = doc_ref.get()
188
- if not doc.exists:
189
  raise HTTPException(status_code=404, detail="Result not found")
190
- data = doc.to_dict()
191
  return AnalyzeResponse(
192
  id=data["id"],
193
- status=data["status"],
194
  threat_score=data.get("threat_score"),
195
  signals=SignalScores(
196
  p_ai=data.get("p_ai"),
@@ -200,8 +181,6 @@ async def get_result(
200
  s_styl=data.get("s_styl"),
201
  p_watermark=data.get("p_watermark"),
202
  ),
203
- explainability=[
204
- ExplainabilityItem(**e) for e in (data.get("explainability") or [])
205
- ],
206
  processing_time_ms=data.get("processing_time_ms"),
207
  )
 
1
  """
2
  Main API routes for the LLM Misuse Detection system.
3
  Endpoints: /api/analyze, /api/analyze/bulk, /api/results/{id}
4
+ Persistence: Firestore via REST helpers.
5
  """
6
  import hashlib
7
+ import json
8
  import time
9
  from datetime import datetime, timezone
10
 
11
+ from fastapi import APIRouter, HTTPException
12
 
13
  from backend.app.api.models import (
14
+ AnalyzeRequest,
15
+ AnalyzeResponse,
16
+ BulkAnalyzeRequest,
17
+ ExplainabilityItem,
18
+ SignalScores,
19
  )
 
20
  from backend.app.core.config import settings
21
+ from backend.app.core.logging import get_logger
22
  from backend.app.core.redis import check_rate_limit, get_cached, set_cached
23
+ from backend.app.db.firestore import get_document, save_document
24
  from backend.app.models.schemas import AnalysisResult
25
  from backend.app.services.ensemble import compute_ensemble
 
26
  from backend.app.services.groq_service import compute_perplexity
27
+ from backend.app.services.hf_service import detect_ai_text, detect_harm, get_embeddings
28
  from backend.app.services.stylometry import compute_stylometry_score
29
  from backend.app.services.vector_db import compute_cluster_score, upsert_embedding
 
 
 
30
 
31
  logger = get_logger(__name__)
32
  router = APIRouter(prefix="/api", tags=["analysis"])
 
34
  COLLECTION = "analysis_results"
35
 
36
 
37
+ async def _analyze_text(text: str, user_id: str | None = None) -> dict:
38
  """Core analysis pipeline for a single text."""
39
  start_time = time.time()
40
  text_hash = hashlib.sha256(text.encode()).hexdigest()
41
 
 
42
  cached = await get_cached(f"analysis:{text_hash}")
43
  if cached:
44
  return json.loads(cached)
45
 
 
46
  try:
47
  p_ai = await detect_ai_text(text)
48
  except Exception:
49
  p_ai = None
50
 
 
51
  s_perp = None
52
  if p_ai is not None and p_ai > settings.PERPLEXITY_THRESHOLD:
53
  s_perp = await compute_perplexity(text)
54
 
 
55
  s_embed_cluster = None
56
  try:
57
  embeddings = await get_embeddings(text)
 
60
  except Exception:
61
  pass
62
 
 
63
  p_ext = await detect_harm(text)
 
 
64
  s_styl = compute_stylometry_score(text)
 
 
65
  p_watermark = None
66
 
 
67
  ensemble_result = compute_ensemble(
68
  p_ai=p_ai,
69
  s_perp=s_perp,
 
88
  "processing_time_ms": processing_time_ms,
89
  }
90
 
 
91
  try:
92
  await set_cached(f"analysis:{text_hash}", json.dumps(result), ttl=600)
93
  except Exception:
94
  pass
95
 
 
96
  try:
 
97
  doc = AnalysisResult(
98
  input_text=text,
99
  text_hash=text_hash,
 
110
  completed_at=datetime.now(timezone.utc),
111
  processing_time_ms=processing_time_ms,
112
  )
113
+ saved = await save_document(COLLECTION, doc.id, doc.to_dict())
114
+ result["id"] = doc.id if saved else text_hash
115
  except Exception as e:
116
  logger.warning("Firestore persist failed", error=str(e))
117
  result["id"] = text_hash
 
140
  s_styl=result["s_styl"],
141
  p_watermark=result["p_watermark"],
142
  ),
143
+ explainability=[ExplainabilityItem(**e) for e in result["explainability"]],
 
 
144
  processing_time_ms=result["processing_time_ms"],
145
  )
146
 
 
162
  return {"results": results}
163
 
164
 
165
+ @router.get("/results/{result_id}", response_model=AnalyzeResponse)
166
+ async def get_result(result_id: str):
 
 
 
167
  """Fetch a previously computed analysis result by Firestore document ID."""
168
+ data = await get_document(COLLECTION, result_id)
169
+ if not data:
 
 
170
  raise HTTPException(status_code=404, detail="Result not found")
171
+
172
  return AnalyzeResponse(
173
  id=data["id"],
174
+ status=data.get("status", "done"),
175
  threat_score=data.get("threat_score"),
176
  signals=SignalScores(
177
  p_ai=data.get("p_ai"),
 
181
  s_styl=data.get("s_styl"),
182
  p_watermark=data.get("p_watermark"),
183
  ),
184
+ explainability=[ExplainabilityItem(**e) for e in (data.get("explainability") or [])],
 
 
185
  processing_time_ms=data.get("processing_time_ms"),
186
  )
backend/app/core/config.py CHANGED
@@ -26,10 +26,10 @@ class Settings(BaseSettings):
26
 
27
  # HuggingFace
28
  HF_API_KEY: str = ""
29
- HF_DETECTOR_PRIMARY: str = f"{_HF_ROUTER}/roberta-base-openai-detector"
30
- HF_DETECTOR_FALLBACK: str = f"{_HF_ROUTER}/Hello-SimpleAI/chatgpt-detector-roberta"
31
- HF_EMBEDDINGS_PRIMARY: str = f"{_HF_ROUTER}/sentence-transformers/all-MiniLM-L6-v2"
32
- HF_EMBEDDINGS_FALLBACK: str = f"{_HF_ROUTER}/sentence-transformers/paraphrase-MiniLM-L3-v2"
33
  HF_HARM_CLASSIFIER: str = f"{_HF_ROUTER}/facebook/roberta-hate-speech-dynabench-r4-target"
34
 
35
  # Groq
 
26
 
27
  # HuggingFace
28
  HF_API_KEY: str = ""
29
+ HF_DETECTOR_PRIMARY: str = f"{_HF_ROUTER}/Hello-SimpleAI/chatgpt-detector-roberta"
30
+ HF_DETECTOR_FALLBACK: str = ""
31
+ HF_EMBEDDINGS_PRIMARY: str = ""
32
+ HF_EMBEDDINGS_FALLBACK: str = ""
33
  HF_HARM_CLASSIFIER: str = f"{_HF_ROUTER}/facebook/roberta-hate-speech-dynabench-r4-target"
34
 
35
  # Groq
backend/app/db/firestore.py CHANGED
@@ -1,68 +1,157 @@
1
  """
2
- Firebase Admin SDK initialisation and Firestore client.
3
 
4
- Fixes:
5
- - Handles escaped newlines in private_key when FIREBASE_CREDENTIALS_JSON
6
- is pasted as a single-line string (\\n must become \n for JWT signing).
 
 
 
 
7
 
8
- Priority for credentials:
9
- 1. FIREBASE_CREDENTIALS_JSON env var (JSON string, production)
10
- 2. GOOGLE_APPLICATION_CREDENTIALS env var (path to file, local dev)
11
  """
 
 
12
  import json
13
- import os
 
14
 
15
- import firebase_admin
16
- from firebase_admin import credentials, firestore
17
 
18
  from backend.app.core.config import settings
19
  from backend.app.core.logging import get_logger
20
 
21
  logger = get_logger(__name__)
22
 
23
- _app: firebase_admin.App | None = None
24
- _db = None
 
 
 
 
25
 
26
 
27
- def _fix_private_key(cred_dict: dict) -> dict:
28
- """
29
- When a service account JSON is pasted as a single-line env var, the
30
- private_key newlines get double-escaped as \\n instead of \n.
31
- This causes 'Invalid JWT Signature' errors at runtime.
32
- Fix: replace literal \\n with real newline in private_key only.
33
- """
34
- if "private_key" in cred_dict:
35
- cred_dict["private_key"] = cred_dict["private_key"].replace("\\n", "\n")
36
- return cred_dict
37
 
38
 
39
  def init_firebase() -> None:
40
- """Initialise the Firebase Admin SDK (idempotent)."""
41
- global _app, _db
42
- if _app is not None:
 
43
  return
44
-
45
- if settings.FIREBASE_CREDENTIALS_JSON:
46
  cred_dict = json.loads(settings.FIREBASE_CREDENTIALS_JSON)
47
  cred_dict = _fix_private_key(cred_dict)
48
- cred = credentials.Certificate(cred_dict)
49
- elif os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
50
- cred = credentials.ApplicationDefault()
51
- else:
52
- raise RuntimeError(
53
- "Firebase credentials not configured. "
54
- "Set FIREBASE_CREDENTIALS_JSON or GOOGLE_APPLICATION_CREDENTIALS."
55
- )
56
-
57
- _app = firebase_admin.initialize_app(
58
- cred,
59
- {"projectId": settings.FIREBASE_PROJECT_ID},
60
- )
61
- _db = firestore.client()
62
- logger.info("Firebase Admin SDK initialised", project=settings.FIREBASE_PROJECT_ID)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
  def get_db():
66
- if _db is None:
67
- raise RuntimeError("Firestore not initialised. Call init_firebase() on startup.")
68
- return _db
 
 
1
  """
2
+ Firestore client using the Firestore REST API over plain HTTPS.
3
 
4
+ Why REST instead of firebase-admin + gRPC:
5
+ The firebase-admin SDK uses gRPC for Firestore. When FIREBASE_CREDENTIALS_JSON
6
+ is stored as an env-var in HF Spaces, the private_key newlines are double-escaped
7
+ (\\n instead of \n), causing 'invalid_grant: Invalid JWT Signature' errors that
8
+ fire in a tight background loop and spam the logs. The REST approach uses
9
+ google-auth (already installed) directly over HTTPS — no gRPC, no background
10
+ token-refresh loop, and the newline fix is applied once at startup.
11
 
12
+ Env vars required:
13
+ FIREBASE_CREDENTIALS_JSON – service account JSON string
14
+ FIREBASE_PROJECT_ID – e.g. "fir-config-d3c36"
15
  """
16
+ from __future__ import annotations
17
+
18
  import json
19
+ import httpx
20
+ from typing import Any
21
 
22
+ import google.oauth2.service_account as sa
23
+ import google.auth.transport.requests as ga_requests
24
 
25
  from backend.app.core.config import settings
26
  from backend.app.core.logging import get_logger
27
 
28
  logger = get_logger(__name__)
29
 
30
+ _SCOPES = ["https://www.googleapis.com/auth/datastore"]
31
+ _FIRESTORE_BASE = "https://firestore.googleapis.com/v1"
32
+
33
+ _credentials: sa.Credentials | None = None
34
+ _project_id: str = ""
35
+ _enabled: bool = False
36
 
37
 
38
+ def _fix_private_key(d: dict) -> dict:
39
+ """Unescape double-escaped newlines in private_key (common in env-var pastes)."""
40
+ if "private_key" in d:
41
+ d["private_key"] = d["private_key"].replace("\\n", "\n")
42
+ return d
 
 
 
 
 
43
 
44
 
45
def init_firebase() -> None:
    """Load service-account credentials for the Firestore REST client.

    Idempotent: once initialisation has succeeded, further calls return
    immediately instead of re-validating the credentials over the network.
    Non-fatal on misconfiguration — Firestore is simply disabled (every
    helper becomes a no-op) so the API keeps serving.
    """
    global _credentials, _project_id, _enabled
    if _enabled:
        # Already initialised successfully — avoid a redundant token refresh.
        return
    if not settings.FIREBASE_CREDENTIALS_JSON:
        logger.warning("FIREBASE_CREDENTIALS_JSON not set – Firestore disabled")
        return
    try:
        cred_dict = json.loads(settings.FIREBASE_CREDENTIALS_JSON)
        cred_dict = _fix_private_key(cred_dict)
        _credentials = sa.Credentials.from_service_account_info(cred_dict, scopes=_SCOPES)
        # Prefer the explicit setting; fall back to the project baked into the key.
        _project_id = settings.FIREBASE_PROJECT_ID or cred_dict.get("project_id", "")
        # Validate credentials once at startup to avoid repeated runtime failures.
        _credentials.refresh(ga_requests.Request())
        _enabled = True
        logger.info("Firebase REST client initialised", project=_project_id)
    except Exception as e:
        _credentials = None
        _enabled = False
        logger.warning("Firebase init failed – Firestore disabled", error=str(e))
65
+
66
+
67
def _auth_headers() -> dict:
    """Return a Bearer-token Authorization header for Firestore REST calls.

    google-auth exposes ``valid`` (token present and not expired); only
    refresh when it is False — the previous unconditional refresh hit the
    OAuth token endpoint on every single Firestore request.
    """
    if not _credentials.valid:
        _credentials.refresh(ga_requests.Request())
    return {"Authorization": f"Bearer {_credentials.token}"}
72
+
73
+
74
def _collection_url(collection: str) -> str:
    """Absolute REST URL of *collection* in the project's default database."""
    database = f"projects/{_project_id}/databases/(default)"
    return f"{_FIRESTORE_BASE}/{database}/documents/{collection}"
76
+
77
+
78
def _doc_url(collection: str, doc_id: str) -> str:
    """Absolute REST URL of one document inside *collection*."""
    return "/".join((_collection_url(collection), doc_id))
80
+
81
+
82
+ def _to_firestore_value(v: Any) -> dict:
83
+ """Convert a Python value to a Firestore REST value object."""
84
+ if isinstance(v, bool):
85
+ return {"booleanValue": v}
86
+ if isinstance(v, int):
87
+ return {"integerValue": str(v)}
88
+ if isinstance(v, float):
89
+ return {"doubleValue": v}
90
+ if isinstance(v, str):
91
+ return {"stringValue": v}
92
+ if v is None:
93
+ return {"nullValue": None}
94
+ if isinstance(v, dict):
95
+ return {"mapValue": {"fields": {k: _to_firestore_value(u) for k, u in v.items()}}}
96
+ if isinstance(v, list):
97
+ return {"arrayValue": {"values": [_to_firestore_value(i) for i in v]}}
98
+ return {"stringValue": str(v)}
99
+
100
+
101
+ def _from_firestore_value(v: dict) -> Any:
102
+ """Convert a Firestore REST value object to a Python value."""
103
+ if "stringValue" in v: return v["stringValue"]
104
+ if "integerValue" in v: return int(v["integerValue"])
105
+ if "doubleValue" in v: return float(v["doubleValue"])
106
+ if "booleanValue" in v: return v["booleanValue"]
107
+ if "nullValue" in v: return None
108
+ if "mapValue" in v: return {k: _from_firestore_value(u) for k, u in v["mapValue"].get("fields", {}).items()}
109
+ if "arrayValue" in v: return [_from_firestore_value(i) for i in v["arrayValue"].get("values", [])]
110
+ return None
111
+
112
+
113
+ # ---- Public helpers --------------------------------------------------------
114
+
115
async def save_document(collection: str, doc_id: str, data: dict) -> bool:
    """Create or overwrite a Firestore document; return True on success.

    Best-effort: when the client is disabled or the request fails, the
    error is logged and False is returned — persistence never raises.
    """
    if not _enabled:
        return False
    try:
        body = {"fields": {name: _to_firestore_value(member) for name, member in data.items()}}
        async with httpx.AsyncClient(timeout=10.0) as client:
            # PATCH without an updateMask creates the doc or replaces it wholesale.
            resp = await client.patch(
                _doc_url(collection, doc_id),
                json=body,
                headers=_auth_headers(),
            )
            resp.raise_for_status()
        return True
    except Exception as e:
        logger.warning("Firestore save_document failed", collection=collection, doc_id=doc_id, error=str(e))
        return False
133
+
134
+
135
async def get_document(collection: str, doc_id: str) -> dict | None:
    """Fetch one Firestore document, decoded to a plain dict.

    Returns None when the client is disabled, the document does not exist
    (HTTP 404), or the request fails — failures are logged, never raised.
    """
    if not _enabled:
        return None
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(_doc_url(collection, doc_id), headers=_auth_headers())
        if resp.status_code == 404:
            return None
        resp.raise_for_status()
        fields = resp.json().get("fields", {})
        return {name: _from_firestore_value(member) for name, member in fields.items()}
    except Exception as e:
        logger.warning("Firestore get_document failed", collection=collection, doc_id=doc_id, error=str(e))
        return None
151
 
152
 
153
  def get_db():
154
+ """Legacy shim for code that calls get_db(). Returns None if Firestore is disabled."""
155
+ if not _enabled:
156
+ return None
157
+ return True # callers should use save_document/get_document directly
backend/app/services/groq_service.py CHANGED
@@ -1,46 +1,32 @@
1
- """
2
- Groq API client for perplexity scoring using Llama models.
3
- Computes token-level log-probabilities to produce perplexity scores.
4
 
5
- Env vars: GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
6
- """
7
  import math
8
- import httpx
9
- from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
10
  from typing import Optional
11
 
 
 
12
  from backend.app.core.config import settings
13
  from backend.app.core.logging import get_logger
14
 
15
  logger = get_logger(__name__)
16
 
17
- _TIMEOUT = httpx.Timeout(60.0, connect=10.0)
 
18
 
19
 
20
- class GroqServiceError(Exception):
21
- pass
22
-
23
 
24
- @retry(
25
- stop=stop_after_attempt(3),
26
- wait=wait_exponential(multiplier=1, min=2, max=15),
27
- retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.ConnectError)),
28
- )
29
- async def _groq_chat_completion(text: str) -> dict:
30
- """Call Groq chat completion with logprobs enabled.
31
- Note: Input is truncated to 2000 chars for cost control. Perplexity
32
- scores for longer texts reflect only the first 2000 characters.
33
- """
34
  headers = {
35
  "Authorization": f"Bearer {settings.GROQ_API_KEY}",
36
  "Content-Type": "application/json",
37
  }
38
  payload = {
39
- "model": settings.GROQ_MODEL,
40
- "messages": [
41
- {"role": "system", "content": "Repeat the following text exactly:"},
42
- {"role": "user", "content": text[:2000]}, # Truncated for cost control
43
- ],
44
  "max_tokens": 1,
45
  "temperature": 0,
46
  "logprobs": True,
@@ -52,54 +38,44 @@ async def _groq_chat_completion(text: str) -> dict:
52
  json=payload,
53
  headers=headers,
54
  )
 
 
 
 
55
  resp.raise_for_status()
56
  return resp.json()
57
 
58
 
59
  async def compute_perplexity(text: str) -> Optional[float]:
60
- """
61
- Compute a normalized perplexity score using Groq Llama endpoints.
62
- Returns a score between 0 and 1 where higher = more anomalous.
63
-
64
- Strategy: Use logprobs from a single completion call to estimate
65
- the model's surprise at the input text.
66
- """
67
  try:
68
  result = await _groq_chat_completion(text)
 
 
 
69
  choices = result.get("choices", [])
70
  if not choices:
71
  return None
72
 
73
- logprobs_data = choices[0].get("logprobs", {})
74
- if not logprobs_data:
75
- # If logprobs not available, use usage-based heuristic
 
76
  usage = result.get("usage", {})
77
  prompt_tokens = usage.get("prompt_tokens", 0)
78
  if prompt_tokens > 0:
79
- text_len = len(text.split())
80
- ratio = prompt_tokens / max(text_len, 1)
81
- # Normalize: high token ratio suggests unusual tokenization
82
- return min(1.0, max(0.0, (ratio - 1.0) / 2.0))
83
  return None
84
 
85
- content = logprobs_data.get("content", [])
86
- if not content:
87
- return None
88
-
89
- # Compute perplexity from log-probabilities
90
- log_probs = []
91
- for token_info in content:
92
- lp = token_info.get("logprob")
93
- if lp is not None:
94
- log_probs.append(lp)
95
-
96
  if not log_probs:
97
  return None
98
 
99
  avg_log_prob = sum(log_probs) / len(log_probs)
100
  perplexity = math.exp(-avg_log_prob)
101
- # Normalize to 0-1 range (perplexity of 1 = perfectly predicted, >100 = very unusual)
102
- normalized = min(1.0, max(0.0, (math.log(perplexity + 1) / math.log(101))))
103
  return round(normalized, 4)
104
  except Exception as e:
105
  logger.warning("Groq perplexity computation failed", error=str(e))
 
1
+ """Groq API client for optional perplexity scoring."""
2
+
3
+ from __future__ import annotations
4
 
 
 
5
  import math
 
 
6
  from typing import Optional
7
 
8
+ import httpx
9
+
10
  from backend.app.core.config import settings
11
  from backend.app.core.logging import get_logger
12
 
13
  logger = get_logger(__name__)
14
 
15
+ _TIMEOUT = httpx.Timeout(30.0, connect=10.0)
16
+ _LOGPROBS_MODEL = "llama-3.1-8b-instant"
17
 
18
 
19
+ async def _groq_chat_completion(text: str) -> Optional[dict]:
20
+ if not settings.GROQ_API_KEY:
21
+ return None
22
 
 
 
 
 
 
 
 
 
 
 
23
  headers = {
24
  "Authorization": f"Bearer {settings.GROQ_API_KEY}",
25
  "Content-Type": "application/json",
26
  }
27
  payload = {
28
+ "model": _LOGPROBS_MODEL,
29
+ "messages": [{"role": "user", "content": text[:1500]}],
 
 
 
30
  "max_tokens": 1,
31
  "temperature": 0,
32
  "logprobs": True,
 
38
  json=payload,
39
  headers=headers,
40
  )
41
+ # 4xx means unsupported model/params for this key-tier; do not spam retries.
42
+ if 400 <= resp.status_code < 500:
43
+ logger.info("Groq perplexity unavailable for current deployment", status_code=resp.status_code)
44
+ return None
45
  resp.raise_for_status()
46
  return resp.json()
47
 
48
 
49
  async def compute_perplexity(text: str) -> Optional[float]:
50
+ """Compute a normalized perplexity score (0-1). Returns None on failure."""
 
 
 
 
 
 
51
  try:
52
  result = await _groq_chat_completion(text)
53
+ if not result:
54
+ return None
55
+
56
  choices = result.get("choices", [])
57
  if not choices:
58
  return None
59
 
60
+ logprobs_data = choices[0].get("logprobs") or {}
61
+ content = logprobs_data.get("content") or []
62
+
63
+ if not content:
64
  usage = result.get("usage", {})
65
  prompt_tokens = usage.get("prompt_tokens", 0)
66
  if prompt_tokens > 0:
67
+ text_len = max(len(text.split()), 1)
68
+ ratio = prompt_tokens / text_len
69
+ return round(min(1.0, max(0.0, (ratio - 1.0) / 2.0)), 4)
 
70
  return None
71
 
72
+ log_probs = [t["logprob"] for t in content if t.get("logprob") is not None]
 
 
 
 
 
 
 
 
 
 
73
  if not log_probs:
74
  return None
75
 
76
  avg_log_prob = sum(log_probs) / len(log_probs)
77
  perplexity = math.exp(-avg_log_prob)
78
+ normalized = min(1.0, max(0.0, math.log(perplexity + 1) / math.log(101)))
 
79
  return round(normalized, 4)
80
  except Exception as e:
81
  logger.warning("Groq perplexity computation failed", error=str(e))
backend/app/services/hf_service.py CHANGED
@@ -1,26 +1,24 @@
1
- """
2
- Hugging Face Inference API client.
3
- Calls AI-text detectors and embedding models hosted on HF Inference Endpoints.
4
- Implements retry/backoff and circuit-breaker behavior.
5
-
6
- Env vars: HF_API_KEY, HF_DETECTOR_PRIMARY, HF_DETECTOR_FALLBACK,
7
- HF_EMBEDDINGS_PRIMARY, HF_EMBEDDINGS_FALLBACK, HF_HARM_CLASSIFIER
8
- """
9
  import httpx
10
- from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
11
- from typing import List, Optional, Dict, Any
12
 
13
  from backend.app.core.config import settings
14
  from backend.app.core.logging import get_logger
15
 
16
  logger = get_logger(__name__)
17
 
18
- _HEADERS = lambda: {"Authorization": f"Bearer {settings.HF_API_KEY}"}
19
  _TIMEOUT = httpx.Timeout(30.0, connect=10.0)
 
20
 
21
 
22
- class HFServiceError(Exception):
23
- pass
24
 
25
 
26
  @retry(
@@ -28,74 +26,97 @@ class HFServiceError(Exception):
28
  wait=wait_exponential(multiplier=1, min=1, max=10),
29
  retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.ConnectError)),
30
  )
31
- async def _hf_request(url: str, payload: dict) -> Any:
32
  async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
33
- resp = await client.post(url, json=payload, headers=_HEADERS())
34
  resp.raise_for_status()
35
  return resp.json()
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  async def detect_ai_text(text: str) -> float:
39
- """
40
- Call AI text detector ensemble (primary + fallback).
41
- Returns probability that text is AI-generated (0-1).
42
- """
43
- scores = []
44
- for url in [settings.HF_DETECTOR_PRIMARY, settings.HF_DETECTOR_FALLBACK]:
45
  try:
46
- result = await _hf_request(url, {"inputs": text})
47
- # HF classification returns [[{label, score}, ...]]
48
  if isinstance(result, list) and len(result) > 0:
49
  labels = result[0] if isinstance(result[0], list) else result
50
  for item in labels:
51
  label = item.get("label", "").lower()
52
- if label in ("ai", "fake", "machine", "ai-generated", "generated"):
53
- scores.append(item["score"])
 
 
 
 
 
 
 
 
 
 
 
 
54
  break
55
  else:
56
- # If no matching label found, use first score as proxy
57
- if labels:
58
- scores.append(labels[0].get("score", 0.5))
59
  except Exception as e:
60
  logger.warning("HF detector call failed", url=url, error=str(e))
 
61
  if not scores:
62
- raise HFServiceError("All AI detectors failed")
63
- return sum(scores) / len(scores)
64
 
65
 
66
- async def get_embeddings(text: str) -> List[float]:
67
- """Get text embeddings from HF sentence-transformers endpoint."""
68
- for url in [settings.HF_EMBEDDINGS_PRIMARY, settings.HF_EMBEDDINGS_FALLBACK]:
69
  try:
70
- result = await _hf_request(url, {"inputs": text})
71
- if isinstance(result, list) and len(result) > 0:
72
- # Returns a list of floats (embedding vector)
73
- if isinstance(result[0], float):
74
- return result
75
- if isinstance(result[0], list):
76
- return result[0]
77
- return result
78
  except Exception as e:
79
  logger.warning("HF embeddings call failed", url=url, error=str(e))
80
- raise HFServiceError("All embedding endpoints failed")
 
 
81
 
82
 
83
  async def detect_harm(text: str) -> float:
84
- """
85
- Call harm/extremism classifier on HF.
86
- Returns probability of harmful/extremist content (0-1).
87
- """
88
  try:
89
- result = await _hf_request(settings.HF_HARM_CLASSIFIER, {"inputs": text})
90
  if isinstance(result, list) and len(result) > 0:
91
  labels = result[0] if isinstance(result[0], list) else result
92
  for item in labels:
93
  label = item.get("label", "").lower()
94
- if label in ("hate", "toxic", "harmful", "extremist", "hateful"):
95
- return item["score"]
96
- # Fallback: return highest score
97
- if labels:
98
- return max(item.get("score", 0.0) for item in labels)
99
  return 0.0
100
  except Exception as e:
101
  logger.warning("HF harm classifier failed", error=str(e))
 
1
+ """Hugging Face Inference API helpers with resilient fallbacks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ from typing import Any
7
+
 
8
  import httpx
9
+ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 
10
 
11
  from backend.app.core.config import settings
12
  from backend.app.core.logging import get_logger
13
 
14
  logger = get_logger(__name__)
15
 
 
16
  _TIMEOUT = httpx.Timeout(30.0, connect=10.0)
17
+ _LOCAL_EMBEDDING_DIM = 384
18
 
19
 
20
+ def _headers() -> dict:
21
+ return {"Authorization": f"Bearer {settings.HF_API_KEY}"}
22
 
23
 
24
  @retry(
 
26
  wait=wait_exponential(multiplier=1, min=1, max=10),
27
  retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.ConnectError)),
28
  )
29
+ async def _hf_post(url: str, payload: dict) -> Any:
30
  async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
31
+ resp = await client.post(url, json=payload, headers=_headers())
32
  resp.raise_for_status()
33
  return resp.json()
34
 
35
 
36
+ def _configured_urls(*urls: str) -> list[str]:
37
+ return [url for url in urls if url and url.strip()]
38
+
39
+
40
def _local_embedding(text: str, dim: int = _LOCAL_EMBEDDING_DIM) -> list[float]:
    """Deterministic no-network embedding fallback to keep the pipeline stable.

    Expands SHA-256 output by hash-chaining into *dim* floats in [-1, 1).
    The vector is not semantically meaningful — it only guarantees that
    identical texts always map to identical vectors.
    """
    if dim <= 0:
        return []
    payload = text.encode("utf-8")
    # The text digest seeds the chain; each round rehashes (block + payload).
    block = hashlib.sha256(payload).digest()
    vector: list[float] = []
    while True:
        block = hashlib.sha256(block + payload).digest()
        for byte in block:
            vector.append((byte / 127.5) - 1.0)
            if len(vector) == dim:
                return vector
52
+
53
+
54
async def detect_ai_text(text: str) -> float:
    """Return the ensemble probability (0-1) that *text* is AI-generated.

    Queries every configured detector endpoint, takes the score of the
    first label that denotes "AI/generated" (falling back to the
    highest-scoring label when none matches), and averages across the
    endpoints that responded.

    Raises:
        RuntimeError: if every configured detector call fails.
    """
    scores: list[float] = []
    for url in _configured_urls(settings.HF_DETECTOR_PRIMARY, settings.HF_DETECTOR_FALLBACK):
        try:
            result = await _hf_post(url, {"inputs": text})
            if isinstance(result, list) and len(result) > 0:
                # HF classifiers return [[{label, score}, ...]] or [{label, score}, ...].
                labels = result[0] if isinstance(result[0], list) else result
                for item in labels:
                    label = item.get("label", "").lower()
                    if any(
                        k in label
                        for k in (
                            "ai",
                            "fake",
                            "machine",
                            "generated",
                            "chatgpt",
                            "gpt",
                            "class_1",
                            "label_1",
                        )
                    ):
                        scores.append(float(item["score"]))
                        break
                else:
                    # No recognisable "AI" label: use the top score as a proxy.
                    best = max(labels, key=lambda x: x.get("score", 0))
                    scores.append(float(best.get("score", 0.5)))
        except Exception as e:
            logger.warning("HF detector call failed", url=url, error=str(e))

    if not scores:
        # RuntimeError instead of bare Exception (anti-pattern); callers in
        # routes.py catch Exception, so they are unaffected.
        raise RuntimeError("All AI detectors failed")
    return round(sum(scores) / len(scores), 4)
88
 
89
 
90
+ async def get_embeddings(text: str) -> list[float]:
91
+ """Returns embedding vector, falling back to deterministic local embedding."""
92
+ for url in _configured_urls(settings.HF_EMBEDDINGS_PRIMARY, settings.HF_EMBEDDINGS_FALLBACK):
93
  try:
94
+ result = await _hf_post(url, {"inputs": text})
95
+ while isinstance(result, list) and result and isinstance(result[0], list):
96
+ result = result[0]
97
+ if isinstance(result, list) and result and isinstance(result[0], (float, int)):
98
+ return [float(v) for v in result]
 
 
 
99
  except Exception as e:
100
  logger.warning("HF embeddings call failed", url=url, error=str(e))
101
+
102
+ logger.info("Using local deterministic embeddings fallback")
103
+ return _local_embedding(text)
104
 
105
 
106
  async def detect_harm(text: str) -> float:
107
+ """Returns probability of harmful content (0-1). Non-fatal on failure."""
108
+ if not settings.HF_HARM_CLASSIFIER:
109
+ return 0.0
110
+
111
  try:
112
+ result = await _hf_post(settings.HF_HARM_CLASSIFIER, {"inputs": text})
113
  if isinstance(result, list) and len(result) > 0:
114
  labels = result[0] if isinstance(result[0], list) else result
115
  for item in labels:
116
  label = item.get("label", "").lower()
117
+ if any(k in label for k in ("hate", "toxic", "harmful", "hateful", "target")):
118
+ return float(item["score"])
119
+ return float(max(labels, key=lambda x: x.get("score", 0)).get("score", 0.0))
 
 
120
  return 0.0
121
  except Exception as e:
122
  logger.warning("HF harm classifier failed", error=str(e))
backend/tests/test_api.py CHANGED
@@ -9,29 +9,16 @@ from fastapi.testclient import TestClient
9
 
10
  def _make_client():
11
  """Build a TestClient with Firebase init and Firestore writes mocked out."""
12
- # Patch firebase_admin before app is imported to prevent SDK initialisation
13
- with patch("backend.app.db.firestore.firebase_admin"), \
14
- patch("backend.app.db.firestore.firestore"):
 
 
15
  from backend.app.main import app
16
 
17
- # Mock get_db() so Firestore document writes are no-ops
18
- mock_db = MagicMock()
19
- mock_collection = MagicMock()
20
- mock_doc_ref = MagicMock()
21
- mock_db.collection.return_value = mock_collection
22
- mock_collection.document.return_value = mock_doc_ref
23
- mock_doc_ref.set.return_value = None
24
-
25
- # Mock get_result Firestore read
26
- mock_existing_doc = MagicMock()
27
- mock_existing_doc.exists = False
28
- mock_doc_ref.get.return_value = mock_existing_doc
29
-
30
  app.dependency_overrides = {}
31
- with patch("backend.app.api.routes.get_db", return_value=mock_db), \
32
- patch("backend.app.db.firestore.init_firebase"):
33
- client = TestClient(app)
34
- return client, mock_db
35
 
36
 
37
  @pytest.fixture
@@ -65,15 +52,11 @@ class TestAnalyzeEndpoint:
65
  @patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.2)
66
  @patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
67
  @patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
68
- @patch("backend.app.api.routes.get_db")
69
  def test_analyze_returns_scores(
70
- self, mock_get_db, mock_set_cache, mock_get_cache, mock_harm, mock_upsert,
71
  mock_cluster, mock_embed, mock_perp, mock_ai, mock_rate, client
72
  ):
73
- mock_db = MagicMock()
74
- mock_db.collection.return_value.document.return_value.set.return_value = None
75
- mock_get_db.return_value = mock_db
76
-
77
  response = client.post(
78
  "/api/analyze",
79
  json={"text": "This is a test text that should be analyzed for potential misuse patterns."},
@@ -110,15 +93,11 @@ class TestAttackSimulations:
110
  @patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.9)
111
  @patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
112
  @patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
113
- @patch("backend.app.api.routes.get_db")
114
  def test_high_threat_detection(
115
- self, mock_get_db, mock_set_cache, mock_get_cache, mock_harm, mock_upsert,
116
  mock_cluster, mock_embed, mock_perp, mock_ai, mock_rate, client
117
  ):
118
- mock_db = MagicMock()
119
- mock_db.collection.return_value.document.return_value.set.return_value = None
120
- mock_get_db.return_value = mock_db
121
-
122
  response = client.post(
123
  "/api/analyze",
124
  json={"text": "Simulated high-threat content for testing purposes only. This is a test."},
@@ -135,15 +114,11 @@ class TestAttackSimulations:
135
  @patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.02)
136
  @patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
137
  @patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
138
- @patch("backend.app.api.routes.get_db")
139
  def test_benign_text_low_threat(
140
- self, mock_get_db, mock_set_cache, mock_get_cache, mock_harm, mock_upsert,
141
  mock_cluster, mock_embed, mock_ai, mock_rate, client
142
  ):
143
- mock_db = MagicMock()
144
- mock_db.collection.return_value.document.return_value.set.return_value = None
145
- mock_get_db.return_value = mock_db
146
-
147
  response = client.post(
148
  "/api/analyze",
149
  json={"text": "The weather today is sunny with clear skies and mild temperatures across the region."},
 
9
 
10
def _make_client():
    """Build a TestClient with Firebase init and Firestore writes mocked out."""
    # The patches stay active only while the app module is imported; they
    # keep firebase/firestore from initialising with real credentials.
    # NOTE(review): save_document/get_document revert once this `with` exits —
    # individual tests re-patch them as needed.
    with patch("backend.app.db.firestore.init_firebase"):
        with patch("backend.app.db.firestore._enabled", True):
            with patch("backend.app.db.firestore.save_document", new_callable=AsyncMock, return_value=True):
                with patch("backend.app.db.firestore.get_document", new_callable=AsyncMock, return_value=None):
                    from backend.app.main import app

    app.dependency_overrides = {}
    return TestClient(app), None
 
 
22
 
23
 
24
  @pytest.fixture
 
52
  @patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.2)
53
  @patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
54
  @patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
55
+ @patch("backend.app.db.firestore.save_document", new_callable=AsyncMock, return_value=True)
56
  def test_analyze_returns_scores(
57
+ self, mock_save, mock_set_cache, mock_get_cache, mock_harm, mock_upsert,
58
  mock_cluster, mock_embed, mock_perp, mock_ai, mock_rate, client
59
  ):
 
 
 
 
60
  response = client.post(
61
  "/api/analyze",
62
  json={"text": "This is a test text that should be analyzed for potential misuse patterns."},
 
93
  @patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.9)
94
  @patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
95
  @patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
96
+ @patch("backend.app.db.firestore.save_document", new_callable=AsyncMock, return_value=True)
97
  def test_high_threat_detection(
98
+ self, mock_save, mock_set_cache, mock_get_cache, mock_harm, mock_upsert,
99
  mock_cluster, mock_embed, mock_perp, mock_ai, mock_rate, client
100
  ):
 
 
 
 
101
  response = client.post(
102
  "/api/analyze",
103
  json={"text": "Simulated high-threat content for testing purposes only. This is a test."},
 
114
  @patch("backend.app.api.routes.detect_harm", new_callable=AsyncMock, return_value=0.02)
115
  @patch("backend.app.api.routes.get_cached", new_callable=AsyncMock, return_value=None)
116
  @patch("backend.app.api.routes.set_cached", new_callable=AsyncMock)
117
+ @patch("backend.app.db.firestore.save_document", new_callable=AsyncMock, return_value=True)
118
  def test_benign_text_low_threat(
119
+ self, mock_save, mock_set_cache, mock_get_cache, mock_harm, mock_upsert,
120
  mock_cluster, mock_embed, mock_ai, mock_rate, client
121
  ):
 
 
 
 
122
  response = client.post(
123
  "/api/analyze",
124
  json={"text": "The weather today is sunny with clear skies and mild temperatures across the region."},
backend/tests/test_services.py CHANGED
@@ -12,9 +12,9 @@ from backend.app.services.groq_service import compute_perplexity
12
 
13
  class TestHFService:
14
  @pytest.mark.asyncio
15
- @patch("backend.app.services.hf_service._hf_request", new_callable=AsyncMock)
16
- async def test_detect_ai_text_success(self, mock_request):
17
- mock_request.return_value = [[
18
  {"label": "AI", "score": 0.92},
19
  {"label": "Human", "score": 0.08},
20
  ]]
@@ -22,27 +22,39 @@ class TestHFService:
22
  assert 0.0 <= score <= 1.0
23
 
24
  @pytest.mark.asyncio
25
- @patch("backend.app.services.hf_service._hf_request", new_callable=AsyncMock)
26
- async def test_detect_ai_text_fallback(self, mock_request):
27
- """If primary fails, should try fallback."""
28
- mock_request.side_effect = [
29
- Exception("Primary failed"),
30
- [[{"label": "FAKE", "score": 0.75}]],
 
 
 
 
 
31
  ]
32
- score = await detect_ai_text("Test text")
 
 
 
33
  assert 0.0 <= score <= 1.0
34
 
35
  @pytest.mark.asyncio
36
- @patch("backend.app.services.hf_service._hf_request", new_callable=AsyncMock)
37
- async def test_get_embeddings_success(self, mock_request):
38
- mock_request.return_value = [0.1] * 768
39
- result = await get_embeddings("Test text")
 
 
 
 
40
  assert len(result) == 768
41
 
42
  @pytest.mark.asyncio
43
- @patch("backend.app.services.hf_service._hf_request", new_callable=AsyncMock)
44
- async def test_detect_harm_success(self, mock_request):
45
- mock_request.return_value = [[
46
  {"label": "hate", "score": 0.15},
47
  {"label": "not_hate", "score": 0.85},
48
  ]]
@@ -77,7 +89,6 @@ class TestGroqService:
77
  "usage": {"prompt_tokens": 15},
78
  }
79
  score = await compute_perplexity("Test text without logprobs available")
80
- # May return None or a heuristic value
81
  assert score is None or 0.0 <= score <= 1.0
82
 
83
  @pytest.mark.asyncio
 
12
 
13
  class TestHFService:
14
  @pytest.mark.asyncio
15
+ @patch("backend.app.services.hf_service._hf_post", new_callable=AsyncMock)
16
+ async def test_detect_ai_text_success(self, mock_post):
17
+ mock_post.return_value = [[
18
  {"label": "AI", "score": 0.92},
19
  {"label": "Human", "score": 0.08},
20
  ]]
 
22
  assert 0.0 <= score <= 1.0
23
 
24
  @pytest.mark.asyncio
25
+ @patch("backend.app.services.hf_service._hf_post", new_callable=AsyncMock)
26
+ async def test_detect_ai_text_fallback(self, mock_post):
27
+ """If primary fails immediately (non-retried error), fallback URL should succeed."""
28
+ # Use ConnectError so tenacity does NOT retry (only HTTPStatusError/ConnectError
29
+ # with stop_after_attempt=3 would retry, but we want ONE failure then fallback).
30
+ # Actually tenacity retries on ConnectError too, so we use a plain Exception
31
+ # which is NOT in the retry predicate — it propagates immediately, letting
32
+ # detect_ai_text catch it in its try/except and move to the fallback URL.
33
+ mock_post.side_effect = [
34
+ Exception("Primary failed"), # primary URL -> caught, move on
35
+ [[{"label": "FAKE", "score": 0.75}]], # fallback URL -> success
36
  ]
37
+ with patch("backend.app.services.hf_service.settings") as mock_settings:
38
+ mock_settings.HF_DETECTOR_PRIMARY = "https://primary.example.com"
39
+ mock_settings.HF_DETECTOR_FALLBACK = "https://fallback.example.com"
40
+ score = await detect_ai_text("Test text")
41
  assert 0.0 <= score <= 1.0
42
 
43
  @pytest.mark.asyncio
44
+ @patch("backend.app.services.hf_service._hf_post", new_callable=AsyncMock)
45
+ async def test_get_embeddings_success(self, mock_post):
46
+ """Mock returns a 768-dim vector; assert we get back exactly that vector."""
47
+ mock_post.return_value = [0.1] * 768
48
+ with patch("backend.app.services.hf_service.settings") as mock_settings:
49
+ mock_settings.HF_EMBEDDINGS_PRIMARY = "https://embeddings.example.com"
50
+ mock_settings.HF_EMBEDDINGS_FALLBACK = ""
51
+ result = await get_embeddings("Test text")
52
  assert len(result) == 768
53
 
54
  @pytest.mark.asyncio
55
+ @patch("backend.app.services.hf_service._hf_post", new_callable=AsyncMock)
56
+ async def test_detect_harm_success(self, mock_post):
57
+ mock_post.return_value = [[
58
  {"label": "hate", "score": 0.15},
59
  {"label": "not_hate", "score": 0.85},
60
  ]]
 
89
  "usage": {"prompt_tokens": 15},
90
  }
91
  score = await compute_perplexity("Test text without logprobs available")
 
92
  assert score is None or 0.0 <= score <= 1.0
93
 
94
  @pytest.mark.asyncio