asahwells committed on
Commit
87cc0f2
·
1 Parent(s): 18e9c9a

Initial commit establishing the project setup and basic structure.

Browse files
app/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """
2
+ Application package for the low-latency moderation API.
3
+ """
4
+
5
+ __all__ = ["config", "models", "services"]
6
+
app/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (280 Bytes). View file
 
app/__pycache__/config.cpython-313.pyc ADDED
Binary file (1.8 kB). View file
 
app/__pycache__/main.cpython-313.pyc ADDED
Binary file (2.1 kB). View file
 
app/__pycache__/models.cpython-313.pyc ADDED
Binary file (1.6 kB). View file
 
app/__pycache__/services.cpython-313.pyc ADDED
Binary file (2.45 kB). View file
 
app/config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration helpers for the moderation service.
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ import os
8
+
9
+
10
@dataclass(frozen=True)
class Settings:
    """Simple immutable settings object.

    Defaults below may be overridden via environment variables; see
    get_settings() in this module for the override logic.
    """

    # Hugging Face model identifier for the toxicity classifier.
    model_name: str = "martin-ha/toxic-comment-model"
    # Minimum confidence at which a TOXIC classification blocks a message.
    negative_threshold: float = 0.90
    # Metadata surfaced in the FastAPI-generated OpenAPI docs.
    api_title: str = "Low-Latency Moderation API"
    api_description: str = (
        "Intercepts chat messages and blocks high-confidence toxic content in under 150ms."
    )
    api_version: str = "1.0.0"
21
+
22
+
23
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """
    Load settings once, allowing environment variables to override defaults.

    Environment variables:
        MODEL_NAME: Hugging Face model identifier.
        NEGATIVE_THRESHOLD: Float between 0 and 1 for blocking TOXIC messages.

    Returns:
        A cached, immutable Settings instance (lru_cache makes this a singleton).

    Raises:
        ValueError: if NEGATIVE_THRESHOLD is set but is not a float in [0, 1].
    """

    model_name = os.getenv("MODEL_NAME", Settings.model_name)

    # Read the env var explicitly instead of passing a float default to
    # os.getenv (which only works incidentally because float() accepts floats).
    raw_threshold = os.getenv("NEGATIVE_THRESHOLD")
    if raw_threshold is None:
        negative_threshold = Settings.negative_threshold
    else:
        negative_threshold = float(raw_threshold)

    # Enforce the documented 0..1 contract rather than silently accepting a
    # value like 7.0, which would disable blocking entirely.
    if not 0.0 <= negative_threshold <= 1.0:
        raise ValueError(
            f"NEGATIVE_THRESHOLD must be between 0 and 1, got {negative_threshold}"
        )

    return Settings(model_name=model_name, negative_threshold=negative_threshold)
37
+
app/main.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application entry point for the low-latency moderation API.
3
+ """
4
+
5
+ from fastapi import FastAPI, HTTPException
6
+
7
+ from .config import get_settings
8
+ from .models import MessagePayload, ModerationResponse
9
+ from . import services
10
+
11
+ settings = get_settings()
12
+
13
+ app = FastAPI(
14
+ title=settings.api_title,
15
+ description=settings.api_description,
16
+ version=settings.api_version,
17
+ )
18
+
19
+
20
+ @app.on_event("startup")
21
+ async def warm_model_cache() -> None:
22
+ """Load the toxicity classification pipeline during startup to avoid request-time latency."""
23
+
24
+ services.get_toxicity_pipeline()
25
+
26
+
27
+ @app.post("/api/check-message", response_model=ModerationResponse)
28
+ async def check_message(payload: MessagePayload) -> ModerationResponse:
29
+ """
30
+ Classify the incoming text and block highly confident toxic content.
31
+ """
32
+
33
+ result = services.analyze_text(payload.text)
34
+
35
+ if result.label == "TOXIC" and result.confidence >= settings.negative_threshold:
36
+ return ModerationResponse(
37
+ status="rejected",
38
+ message="Message classified as toxic with high confidence.",
39
+ label=result.label,
40
+ confidence=result.confidence,
41
+ )
42
+
43
+
44
+ return ModerationResponse(
45
+ status= "rejected" if result.label == "TOXIC" else "accepted",
46
+ message="Message passed moderation.",
47
+ label=result.label,
48
+ confidence=result.confidence,
49
+ )
50
+
app/models.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models shared across the API.
3
+ """
4
+
5
+ from typing import Literal
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
class MessagePayload(BaseModel):
    """Request body for the moderation endpoint: one user chat message."""

    # No default supplied, so pydantic treats the field as required;
    # min_length guards against empty strings.
    text: str = Field(min_length=1, description="User-provided message to moderate.")
14
+
15
+
16
class ToxicityResult(BaseModel):
    """Normalized output from the toxicity classification model."""

    # Produced by services.analyze_text, which upper-cases the raw model label.
    label: Literal["TOXIC", "NON-TOXIC"]
    # Model score for the predicted label, constrained to [0, 1].
    confidence: float = Field(..., ge=0.0, le=1.0)
21
+
22
+
23
class ModerationResponse(BaseModel):
    """API response returned to the chat application."""

    # "rejected" means the message was blocked by moderation.
    status: Literal["accepted", "rejected"]
    # Human-readable explanation of the decision.
    message: str
    label: Literal["TOXIC", "NON-TOXIC"]
    # Model confidence; unlike ToxicityResult, no range constraint is applied here.
    confidence: float
30
+
app/services.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Service layer for interacting with the Hugging Face toxic comment classification pipeline.
3
+ """
4
+
5
+ from functools import lru_cache
6
+ from typing import Callable, List, TypedDict
7
+
8
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline
9
+
10
+ from .config import get_settings
11
+ from .models import ToxicityResult
12
+
13
+
14
class PipelineOutput(TypedDict):
    """Single prediction dict from the Hugging Face text-classification pipeline."""

    label: str  # raw model label; upper-cased downstream by analyze_text
    score: float  # model confidence for the predicted label
17
+
18
+
19
def _build_pipeline() -> Callable[[str], List[PipelineOutput]]:
    """
    Construct the Hugging Face TextClassificationPipeline for toxic comment
    detection, loading the tokenizer and model named in the settings.

    Returns a callable that accepts text and yields label/score dictionaries.
    """

    cfg = get_settings()
    tok = AutoTokenizer.from_pretrained(cfg.model_name)
    clf = AutoModelForSequenceClassification.from_pretrained(cfg.model_name)
    return TextClassificationPipeline(model=clf, tokenizer=tok)
29
+
30
+
31
@lru_cache(maxsize=1)
def get_toxicity_pipeline() -> Callable[[str], List[PipelineOutput]]:
    """Return a cached instance of the toxicity classification pipeline.

    lru_cache(maxsize=1) makes this an app-wide singleton; the first call pays
    the model-load cost (triggered eagerly at startup by app.main).
    """

    return _build_pipeline()
36
+
37
+
38
def analyze_text(text: str) -> ToxicityResult:
    """
    Run toxicity classification on *text* and normalize the response.

    The pipeline returns a one-element list of {"label", "score"} dicts; the
    label is upper-cased to match ToxicityResult's TOXIC/NON-TOXIC literals.
    NOTE(review): assumes the model emits "toxic"/"non-toxic" labels — any
    other raw label will fail ToxicityResult validation; confirm if the
    configured model changes.

    Args:
        text: User message to classify.

    Returns:
        ToxicityResult with the normalized label and float confidence.
    """

    predictor = get_toxicity_pipeline()
    prediction = predictor(text)[0]
    normalized_label = prediction["label"].upper()
    confidence = float(prediction["score"])
    # Removed commented-out debug print left over from development.

    return ToxicityResult(label=normalized_label, confidence=confidence)
51
+
check_model_size.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script to check the memory usage of the loaded toxicity classification model.
3
+ """
4
+
5
+ import os
6
+ import psutil
7
+ import torch
8
+
9
+ from app.services import get_toxicity_pipeline
10
+
11
+
12
def format_bytes(bytes_value: int) -> str:
    """Render a byte count as a human-readable string (two decimal places)."""
    size = float(bytes_value)
    for unit in ("B", "KB", "MB", "GB"):
        if size < 1024.0:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    # Anything >= 1024 GB is reported in terabytes.
    return f"{size:.2f} TB"
19
+
20
+
21
def main():
    """Load the model and display memory usage statistics."""

    # psutil handle for the current process; memory_info().rss is the
    # resident set size in bytes.
    process = psutil.Process(os.getpid())

    # Get baseline memory before loading
    baseline_memory = process.memory_info().rss
    print(f"Baseline RAM usage: {format_bytes(baseline_memory)}")
    print()

    print("Loading model...")
    get_toxicity_pipeline()  # Forces cache to load it
    print("Model loaded!")
    print()

    # Get memory after loading
    final_memory = process.memory_info().rss
    # NOTE(review): the rss delta is an approximation — any other allocation
    # between the two samples is attributed to the model as well.
    model_memory = final_memory - baseline_memory

    print("=" * 50)
    print("Memory Statistics:")
    print("=" * 50)
    print(f"Baseline RAM: {format_bytes(baseline_memory)}")
    print(f"Final RAM: {format_bytes(final_memory)}")
    print(f"Model RAM: {format_bytes(model_memory)}")
    print()

    # Additional system info
    print("System Information:")
    print("=" * 50)
    total_memory = psutil.virtual_memory().total
    available_memory = psutil.virtual_memory().available
    print(f"Total RAM: {format_bytes(total_memory)}")
    print(f"Available RAM: {format_bytes(available_memory)}")
    print(f"RAM Used: {format_bytes(total_memory - available_memory)}")
    print()

    # PyTorch GPU info if available
    if torch.cuda.is_available():
        print("GPU Information:")
        print("=" * 50)
        for i in range(torch.cuda.device_count()):
            gpu_memory = torch.cuda.get_device_properties(i).total_memory
            print(f"GPU {i}: {format_bytes(gpu_memory)}")
    else:
        print("GPU: Not available (using CPU)")


if __name__ == "__main__":
    main()
71
+
main.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from app.main import app
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.111.0
2
+ uvicorn[standard]==0.30.1
3
+ transformers==4.44.0
4
+ torch==2.3.1
5
+ pydantic==2.7.1
6
+ pytest==8.3.2
7
+ psutil==5.9.8
8
+
tests/test_api.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Lightweight tests for the moderation endpoint.
3
+ """
4
+
5
+ from fastapi.testclient import TestClient
6
+ import pytest
7
+
8
+ from app.main import app
9
+ from app import services
10
+ from app.models import SentimentResult
11
+
12
+ client = TestClient(app)
13
+
14
+
15
+ @pytest.fixture(autouse=True)
16
+ def clear_cache():
17
+ """Ensure per-test isolation for cached pipeline calls."""
18
+
19
+ services.get_sentiment_pipeline.cache_clear() # type: ignore[attr-defined]
20
+ yield
21
+ services.get_sentiment_pipeline.cache_clear() # type: ignore[attr-defined]
22
+
23
+
24
+ def test_rejects_high_confidence_negative(monkeypatch):
25
+ """Requests should be blocked when the model is confident a message is negative."""
26
+
27
+ monkeypatch.setattr(
28
+ services,
29
+ "analyze_text",
30
+ lambda _text: SentimentResult(label="NEGATIVE", confidence=0.93),
31
+ )
32
+
33
+ response = client.post("/api/check-message", json={"text": "You are awful."})
34
+ assert response.status_code == 400
35
+ detail = response.json()
36
+ assert detail["status"] == "rejected"
37
+ assert detail["label"] == "NEGATIVE"
38
+
39
+
40
+ def test_accepts_positive_or_low_confidence(monkeypatch):
41
+ """Requests should succeed when the message is allowed."""
42
+
43
+ monkeypatch.setattr(
44
+ services,
45
+ "analyze_text",
46
+ lambda _text: SentimentResult(label="POSITIVE", confidence=0.52),
47
+ )
48
+
49
+ response = client.post("/api/check-message", json={"text": "Great job!"})
50
+ payload = response.json()
51
+
52
+ assert response.status_code == 200
53
+ assert payload["status"] == "accepted"
54
+ assert payload["label"] == "POSITIVE"
55
+