Spaces:
Sleeping
Sleeping
Commit ·
2571402
1
Parent(s): 2d0ef3b
[REF] api documentation
Browse files- .env.example +1 -0
- README.md +7 -0
- app/core/config.py +1 -0
- app/routers/classification.py +1 -1
- app/schemas/classification.py +8 -1
- app/services/classifier_service.py +61 -3
- app/services/label_service.py +3 -0
- docs/reference/api.md +25 -1
- docs/reference/configuration.md +2 -0
- docs/tutorials/getting-started.md +1 -1
- tests/test_classifier_service.py +115 -2
- tests/test_routes.py +31 -1
.env.example
CHANGED
|
@@ -8,5 +8,6 @@ UPLOAD_SUBDIR=uploads
|
|
| 8 |
CLASSIFIER_MODEL=AyoubChLin/bert-base-uncased-zeroshot-nli
|
| 9 |
ENABLE_MODEL_QUANTIZATION=true
|
| 10 |
HUGGINGFACE_TOKEN=
|
|
|
|
| 11 |
|
| 12 |
DEFAULT_LABELS_CSV=news,sport,finance,politics
|
|
|
|
| 8 |
CLASSIFIER_MODEL=AyoubChLin/bert-base-uncased-zeroshot-nli
|
| 9 |
ENABLE_MODEL_QUANTIZATION=true
|
| 10 |
HUGGINGFACE_TOKEN=
|
| 11 |
+
CLASSIFIER_ENTAILMENT_LABEL_ID=
|
| 12 |
|
| 13 |
DEFAULT_LABELS_CSV=news,sport,finance,politics
|
README.md
CHANGED
|
@@ -27,6 +27,12 @@ Refactored into a modular FastAPI backend with clear layers:
|
|
| 27 |
- `POST /configlabel` -> returns labels array
|
| 28 |
- `GET /labels` -> returns labels array
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
Additional operational endpoints:
|
| 31 |
- `GET /health/liveness`
|
| 32 |
- `GET /health/readiness`
|
|
@@ -42,6 +48,7 @@ Key vars:
|
|
| 42 |
- `CLASSIFIER_MODEL`
|
| 43 |
- `ENABLE_MODEL_QUANTIZATION`
|
| 44 |
- `HUGGINGFACE_TOKEN`
|
|
|
|
| 45 |
- `DEFAULT_LABELS_CSV`
|
| 46 |
|
| 47 |
## Local Run
|
|
|
|
| 27 |
- `POST /configlabel` -> returns labels array
|
| 28 |
- `GET /labels` -> returns labels array
|
| 29 |
|
| 30 |
+
`POST /configlabel` exact payload:
|
| 31 |
+
- body accepts `{"labels":["label1","label2","label3"]}`
|
| 32 |
+
- all resulting labels are trimmed, empty values removed
|
| 33 |
+
- duplicates are kept if they are provided
|
| 34 |
+
- returns the stored `string[]` labels
|
| 35 |
+
|
| 36 |
Additional operational endpoints:
|
| 37 |
- `GET /health/liveness`
|
| 38 |
- `GET /health/readiness`
|
|
|
|
| 48 |
- `CLASSIFIER_MODEL`
|
| 49 |
- `ENABLE_MODEL_QUANTIZATION`
|
| 50 |
- `HUGGINGFACE_TOKEN`
|
| 51 |
+
- `CLASSIFIER_ENTAILMENT_LABEL_ID` (optional override when model config has no entailment label name)
|
| 52 |
- `DEFAULT_LABELS_CSV`
|
| 53 |
|
| 54 |
## Local Run
|
app/core/config.py
CHANGED
|
@@ -18,6 +18,7 @@ class Settings(BaseSettings):
|
|
| 18 |
classifier_model: str = "AyoubChLin/bert-base-uncased-zeroshot-nli"
|
| 19 |
enable_model_quantization: bool = True
|
| 20 |
huggingface_token: str | None = None
|
|
|
|
| 21 |
|
| 22 |
default_labels_csv: str = Field(default="news,sport,finance,politics")
|
| 23 |
|
|
|
|
| 18 |
classifier_model: str = "AyoubChLin/bert-base-uncased-zeroshot-nli"
|
| 19 |
enable_model_quantization: bool = True
|
| 20 |
huggingface_token: str | None = None
|
| 21 |
+
classifier_entailment_label_id: int | None = None
|
| 22 |
|
| 23 |
default_labels_csv: str = Field(default="news,sport,finance,politics")
|
| 24 |
|
app/routers/classification.py
CHANGED
|
@@ -56,7 +56,7 @@ async def classify_uploaded_file(file: UploadFile = File(...)) -> dict:
|
|
| 56 |
|
| 57 |
@router.post("/configlabel", response_model=list[str])
|
| 58 |
async def configure_labels(payload: LabelUpdateInput) -> list[str]:
|
| 59 |
-
labels = label_service.
|
| 60 |
if not labels:
|
| 61 |
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one label is required")
|
| 62 |
return labels
|
|
|
|
| 56 |
|
| 57 |
@router.post("/configlabel", response_model=list[str])
|
| 58 |
async def configure_labels(payload: LabelUpdateInput) -> list[str]:
|
| 59 |
+
labels = label_service.set_labels(payload.get_normalized_labels())
|
| 60 |
if not labels:
|
| 61 |
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one label is required")
|
| 62 |
return labels
|
app/schemas/classification.py
CHANGED
|
@@ -10,7 +10,14 @@ class TextInput(BaseSchema):
|
|
| 10 |
|
| 11 |
|
| 12 |
class LabelUpdateInput(BaseSchema):
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
class ClassifierResponse(BaseSchema):
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class LabelUpdateInput(BaseSchema):
|
| 13 |
+
labels: list[str] = Field(
|
| 14 |
+
min_length=1,
|
| 15 |
+
description="Direct list of labels. Items are trimmed and empty values are removed.",
|
| 16 |
+
examples=[["news", "sport", "finance"]],
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
def get_normalized_labels(self) -> list[str]:
|
| 20 |
+
return [label.strip() for label in self.labels if isinstance(label, str) and label.strip()]
|
| 21 |
|
| 22 |
|
| 23 |
class ClassifierResponse(BaseSchema):
|
app/services/classifier_service.py
CHANGED
|
@@ -61,19 +61,77 @@ class ClassifierService:
|
|
| 61 |
cleaned = [label.strip() for label in labels if isinstance(label, str) and label.strip()]
|
| 62 |
return list(dict.fromkeys(cleaned))
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
@staticmethod
|
| 65 |
def _resolve_entailment_id(model: Any) -> int:
|
| 66 |
label2id = getattr(model.config, "label2id", {}) or {}
|
| 67 |
for label, label_id in label2id.items():
|
| 68 |
if isinstance(label, str) and label.lower().startswith("entail"):
|
| 69 |
-
|
|
|
|
|
|
|
| 70 |
|
| 71 |
id2label = getattr(model.config, "id2label", {}) or {}
|
| 72 |
for label_id, label in id2label.items():
|
| 73 |
if isinstance(label, str) and label.lower().startswith("entail"):
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
raise ClassificationError(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
def classify(self, text: str, labels: list[str]) -> str:
|
| 79 |
candidate_labels = self._normalize_labels(labels)
|
|
|
|
| 61 |
cleaned = [label.strip() for label in labels if isinstance(label, str) and label.strip()]
|
| 62 |
return list(dict.fromkeys(cleaned))
|
| 63 |
|
| 64 |
+
@staticmethod
|
| 65 |
+
def _parse_label_id(value: Any) -> int | None:
|
| 66 |
+
try:
|
| 67 |
+
return int(value)
|
| 68 |
+
except (TypeError, ValueError):
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def _extract_task_specific_entailment_id(model: Any) -> int | None:
|
| 73 |
+
task_specific_params = getattr(model.config, "task_specific_params", {}) or {}
|
| 74 |
+
if not isinstance(task_specific_params, dict):
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
zero_shot_params = task_specific_params.get("zero-shot-classification", {})
|
| 78 |
+
if not isinstance(zero_shot_params, dict):
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
return ClassifierService._parse_label_id(zero_shot_params.get("entailment_id"))
|
| 82 |
+
|
| 83 |
+
@staticmethod
|
| 84 |
+
def _has_generic_label_names(model: Any) -> bool:
|
| 85 |
+
label2id = getattr(model.config, "label2id", {}) or {}
|
| 86 |
+
id2label = getattr(model.config, "id2label", {}) or {}
|
| 87 |
+
|
| 88 |
+
labels: list[str] = []
|
| 89 |
+
labels.extend(label for label in label2id.keys() if isinstance(label, str))
|
| 90 |
+
labels.extend(label for label in id2label.values() if isinstance(label, str))
|
| 91 |
+
if not labels:
|
| 92 |
+
return False
|
| 93 |
+
|
| 94 |
+
return all(label.lower().startswith("label_") for label in labels)
|
| 95 |
+
|
| 96 |
@staticmethod
|
| 97 |
def _resolve_entailment_id(model: Any) -> int:
|
| 98 |
label2id = getattr(model.config, "label2id", {}) or {}
|
| 99 |
for label, label_id in label2id.items():
|
| 100 |
if isinstance(label, str) and label.lower().startswith("entail"):
|
| 101 |
+
parsed = ClassifierService._parse_label_id(label_id)
|
| 102 |
+
if parsed is not None:
|
| 103 |
+
return parsed
|
| 104 |
|
| 105 |
id2label = getattr(model.config, "id2label", {}) or {}
|
| 106 |
for label_id, label in id2label.items():
|
| 107 |
if isinstance(label, str) and label.lower().startswith("entail"):
|
| 108 |
+
parsed = ClassifierService._parse_label_id(label_id)
|
| 109 |
+
if parsed is not None:
|
| 110 |
+
return parsed
|
| 111 |
+
|
| 112 |
+
task_specific_entailment_id = ClassifierService._extract_task_specific_entailment_id(model)
|
| 113 |
+
if task_specific_entailment_id is not None:
|
| 114 |
+
return task_specific_entailment_id
|
| 115 |
+
|
| 116 |
+
if settings.classifier_entailment_label_id is not None:
|
| 117 |
+
return settings.classifier_entailment_label_id
|
| 118 |
+
|
| 119 |
+
num_labels = ClassifierService._parse_label_id(getattr(model.config, "num_labels", None))
|
| 120 |
+
if num_labels == 3 and (
|
| 121 |
+
ClassifierService._has_generic_label_names(model) or (not label2id and not id2label)
|
| 122 |
+
):
|
| 123 |
+
logger.warning(
|
| 124 |
+
"Falling back to entailment label id 2 because model config labels are generic or missing "
|
| 125 |
+
"and no explicit entailment mapping was found. Set CLASSIFIER_ENTAILMENT_LABEL_ID "
|
| 126 |
+
"to override this behavior."
|
| 127 |
+
)
|
| 128 |
+
return 2
|
| 129 |
|
| 130 |
+
raise ClassificationError(
|
| 131 |
+
"Classifier model is missing an entailment label mapping. "
|
| 132 |
+
"Set CLASSIFIER_ENTAILMENT_LABEL_ID in the environment when the model config "
|
| 133 |
+
"does not expose an entailment label."
|
| 134 |
+
)
|
| 135 |
|
| 136 |
def classify(self, text: str, labels: list[str]) -> str:
|
| 137 |
candidate_labels = self._normalize_labels(labels)
|
app/services/label_service.py
CHANGED
|
@@ -14,5 +14,8 @@ class LabelService:
|
|
| 14 |
labels = [label.strip() for label in labels_csv.split(",") if label.strip()]
|
| 15 |
return self._config.set_labels(labels)
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
label_service = LabelService()
|
|
|
|
| 14 |
labels = [label.strip() for label in labels_csv.split(",") if label.strip()]
|
| 15 |
return self._config.set_labels(labels)
|
| 16 |
|
| 17 |
+
def set_labels(self, labels: list[str]) -> list[str]:
|
| 18 |
+
return self._config.set_labels(labels)
|
| 19 |
+
|
| 20 |
|
| 21 |
label_service = LabelService()
|
docs/reference/api.md
CHANGED
|
@@ -18,9 +18,33 @@ Evidence:
|
|
| 18 |
| POST | `/api/language` | `{text}` | `"<language>"` |
|
| 19 |
| POST | `/api/transformer` | multipart `file` | `{filename, content}` |
|
| 20 |
| POST | `/classify` | multipart `file` | `{label, language, type?}` |
|
| 21 |
-
| POST | `/configlabel` | `{
|
| 22 |
| GET | `/labels` | none | `string[]` |
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
## Validation and errors
|
| 25 |
- `400` for input validation and extraction problems.
|
| 26 |
- `502` for classifier/language inference failures.
|
|
|
|
| 18 |
| POST | `/api/language` | `{text}` | `"<language>"` |
|
| 19 |
| POST | `/api/transformer` | multipart `file` | `{filename, content}` |
|
| 20 |
| POST | `/classify` | multipart `file` | `{label, language, type?}` |
|
| 21 |
+
| POST | `/configlabel` | `{labels:["a","b"]}` | `string[]` |
|
| 22 |
| GET | `/labels` | none | `string[]` |
|
| 23 |
|
| 24 |
+
## `POST /configlabel` contract (exact)
|
| 25 |
+
|
| 26 |
+
Request body:
|
| 27 |
+
- JSON object with required `labels` array.
|
| 28 |
+
- Example: `{"labels":["tech","health","legal"]}`
|
| 29 |
+
|
| 30 |
+
Normalization behavior:
|
| 31 |
+
- Trim whitespace for each label.
|
| 32 |
+
- Remove empty entries.
|
| 33 |
+
- Preserve order.
|
| 34 |
+
- Keep duplicates as provided.
|
| 35 |
+
|
| 36 |
+
Response:
|
| 37 |
+
- `200` with `string[]`, the stored labels after normalization.
|
| 38 |
+
- Example response: `["tech", "health", "legal"]`
|
| 39 |
+
|
| 40 |
+
Errors:
|
| 41 |
+
- `400` with detail `"At least one label is required"` when all parsed labels are empty.
|
| 42 |
+
- `422` when `labels` is missing or unknown fields are provided (schema validation error).
|
| 43 |
+
|
| 44 |
+
Related state behavior:
|
| 45 |
+
- Labels are process-local in memory and reset on restart.
|
| 46 |
+
- `GET /labels` returns the same current in-memory list.
|
| 47 |
+
|
| 48 |
## Validation and errors
|
| 49 |
- `400` for input validation and extraction problems.
|
| 50 |
- `502` for classifier/language inference failures.
|
docs/reference/configuration.md
CHANGED
|
@@ -29,6 +29,7 @@ Evidence:
|
|
| 29 |
| `CLASSIFIER_MODEL` | `AyoubChLin/bert-base-uncased-zeroshot-nli` | Hugging Face model ID used for local zero-shot NLI classification |
|
| 30 |
| `ENABLE_MODEL_QUANTIZATION` | `true` | enable dynamic INT8 quantization with automatic fallback |
|
| 31 |
| `HUGGINGFACE_TOKEN` | empty | optional auth token for client init |
|
|
|
|
| 32 |
|
| 33 |
## Language detector settings
|
| 34 |
|
|
@@ -45,6 +46,7 @@ Evidence:
|
|
| 45 |
## Behavior notes
|
| 46 |
- Labels are process-local in memory and reset on restart.
|
| 47 |
- Upload directory is auto-created at app startup.
|
|
|
|
| 48 |
|
| 49 |
Evidence:
|
| 50 |
- `app/services/label_service.py`
|
|
|
|
| 29 |
| `CLASSIFIER_MODEL` | `AyoubChLin/bert-base-uncased-zeroshot-nli` | Hugging Face model ID used for local zero-shot NLI classification |
|
| 30 |
| `ENABLE_MODEL_QUANTIZATION` | `true` | enable dynamic INT8 quantization with automatic fallback |
|
| 31 |
| `HUGGINGFACE_TOKEN` | empty | optional auth token for client init |
|
| 32 |
+
| `CLASSIFIER_ENTAILMENT_LABEL_ID` | empty | optional integer override for entailment logit index when model config does not expose an `entailment` label |
|
| 33 |
|
| 34 |
## Language detector settings
|
| 35 |
|
|
|
|
| 46 |
## Behavior notes
|
| 47 |
- Labels are process-local in memory and reset on restart.
|
| 48 |
- Upload directory is auto-created at app startup.
|
| 49 |
+
- If `label2id`/`id2label` does not include an entailment label, the service checks task-specific config, then `CLASSIFIER_ENTAILMENT_LABEL_ID`, then falls back to index `2` for 3-logit generic/missing mappings.
|
| 50 |
|
| 51 |
Evidence:
|
| 52 |
- `app/services/label_service.py`
|
docs/tutorials/getting-started.md
CHANGED
|
@@ -60,7 +60,7 @@ Evidence:
|
|
| 60 |
```bash
|
| 61 |
curl -s -X POST http://localhost:4002/configlabel \
|
| 62 |
-H 'content-type: application/json' \
|
| 63 |
-
-d '{"
|
| 64 |
|
| 65 |
curl -s http://localhost:4002/labels
|
| 66 |
```
|
|
|
|
| 60 |
```bash
|
| 61 |
curl -s -X POST http://localhost:4002/configlabel \
|
| 62 |
-H 'content-type: application/json' \
|
| 63 |
+
-d '{"labels":["tech","health","legal"]}'
|
| 64 |
|
| 65 |
curl -s http://localhost:4002/labels
|
| 66 |
```
|
tests/test_classifier_service.py
CHANGED
|
@@ -15,9 +15,9 @@ class _FakeTokenizer:
|
|
| 15 |
|
| 16 |
|
| 17 |
class _FakeInferenceModel:
|
| 18 |
-
def __init__(self, logits: torch.Tensor) -> None:
|
| 19 |
self._logits = logits
|
| 20 |
-
self.config = SimpleNamespace(
|
| 21 |
label2id={"CONTRADICTION": 0, "ENTAILMENT": 1},
|
| 22 |
id2label={0: "CONTRADICTION", 1: "ENTAILMENT"},
|
| 23 |
)
|
|
@@ -63,6 +63,119 @@ def test_classify_uses_runtime_candidate_labels(monkeypatch):
|
|
| 63 |
assert predicted == "sport"
|
| 64 |
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
def test_model_quantization_falls_back_to_non_quantized_model(monkeypatch):
|
| 67 |
service = classifier_module.ClassifierService()
|
| 68 |
fake_model = _FakeLoadModel()
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class _FakeInferenceModel:
|
| 18 |
+
def __init__(self, logits: torch.Tensor, config: SimpleNamespace | None = None) -> None:
|
| 19 |
self._logits = logits
|
| 20 |
+
self.config = config or SimpleNamespace(
|
| 21 |
label2id={"CONTRADICTION": 0, "ENTAILMENT": 1},
|
| 22 |
id2label={0: "CONTRADICTION", 1: "ENTAILMENT"},
|
| 23 |
)
|
|
|
|
| 63 |
assert predicted == "sport"
|
| 64 |
|
| 65 |
|
| 66 |
+
def test_classify_uses_task_specific_entailment_id_when_label_names_are_generic(monkeypatch):
|
| 67 |
+
service = classifier_module.ClassifierService()
|
| 68 |
+
tokenizer = _FakeTokenizer()
|
| 69 |
+
model = _FakeInferenceModel(
|
| 70 |
+
logits=torch.tensor(
|
| 71 |
+
[
|
| 72 |
+
[1.8, 0.3, 0.4], # finance -> low entailment
|
| 73 |
+
[0.4, 0.7, 3.7], # sport -> highest entailment
|
| 74 |
+
]
|
| 75 |
+
),
|
| 76 |
+
config=SimpleNamespace(
|
| 77 |
+
label2id={"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2},
|
| 78 |
+
id2label={0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"},
|
| 79 |
+
task_specific_params={"zero-shot-classification": {"entailment_id": 2}},
|
| 80 |
+
num_labels=3,
|
| 81 |
+
),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
|
| 85 |
+
monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None)
|
| 86 |
+
|
| 87 |
+
predicted = service.classify(
|
| 88 |
+
"The story is mostly about football transfers.",
|
| 89 |
+
["finance", "sport"],
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
assert predicted == "sport"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def test_classify_uses_explicit_entailment_id_setting_when_mapping_is_missing(monkeypatch):
|
| 96 |
+
service = classifier_module.ClassifierService()
|
| 97 |
+
tokenizer = _FakeTokenizer()
|
| 98 |
+
model = _FakeInferenceModel(
|
| 99 |
+
logits=torch.tensor(
|
| 100 |
+
[
|
| 101 |
+
[2.0, 0.3], # finance -> low entailment
|
| 102 |
+
[0.2, 3.4], # sport -> highest entailment
|
| 103 |
+
]
|
| 104 |
+
),
|
| 105 |
+
config=SimpleNamespace(
|
| 106 |
+
label2id={"NEGATIVE": 0, "POSITIVE": 1},
|
| 107 |
+
id2label={0: "NEGATIVE", 1: "POSITIVE"},
|
| 108 |
+
num_labels=2,
|
| 109 |
+
),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
|
| 113 |
+
monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", 1)
|
| 114 |
+
|
| 115 |
+
predicted = service.classify(
|
| 116 |
+
"The story is mostly about football transfers.",
|
| 117 |
+
["finance", "sport"],
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
assert predicted == "sport"
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def test_classify_falls_back_to_mnli_entailment_index_for_generic_three_label_configs(monkeypatch):
|
| 124 |
+
service = classifier_module.ClassifierService()
|
| 125 |
+
tokenizer = _FakeTokenizer()
|
| 126 |
+
model = _FakeInferenceModel(
|
| 127 |
+
logits=torch.tensor(
|
| 128 |
+
[
|
| 129 |
+
[2.3, 0.6, 0.8], # finance -> low entailment
|
| 130 |
+
[0.4, 0.8, 3.9], # sport -> highest entailment
|
| 131 |
+
]
|
| 132 |
+
),
|
| 133 |
+
config=SimpleNamespace(
|
| 134 |
+
label2id={"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2},
|
| 135 |
+
id2label={0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"},
|
| 136 |
+
num_labels=3,
|
| 137 |
+
),
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
|
| 141 |
+
monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None)
|
| 142 |
+
|
| 143 |
+
predicted = service.classify(
|
| 144 |
+
"The story is mostly about football transfers.",
|
| 145 |
+
["finance", "sport"],
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
assert predicted == "sport"
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test_classify_falls_back_to_mnli_entailment_index_for_missing_label_mapping(monkeypatch):
|
| 152 |
+
service = classifier_module.ClassifierService()
|
| 153 |
+
tokenizer = _FakeTokenizer()
|
| 154 |
+
model = _FakeInferenceModel(
|
| 155 |
+
logits=torch.tensor(
|
| 156 |
+
[
|
| 157 |
+
[1.8, 0.4, 0.5], # finance -> low entailment
|
| 158 |
+
[0.2, 0.5, 3.6], # sport -> highest entailment
|
| 159 |
+
]
|
| 160 |
+
),
|
| 161 |
+
config=SimpleNamespace(
|
| 162 |
+
label2id={},
|
| 163 |
+
id2label={},
|
| 164 |
+
num_labels=3,
|
| 165 |
+
),
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
|
| 169 |
+
monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None)
|
| 170 |
+
|
| 171 |
+
predicted = service.classify(
|
| 172 |
+
"The story is mostly about football transfers.",
|
| 173 |
+
["finance", "sport"],
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
assert predicted == "sport"
|
| 177 |
+
|
| 178 |
+
|
| 179 |
def test_model_quantization_falls_back_to_non_quantized_model(monkeypatch):
|
| 180 |
service = classifier_module.ClassifierService()
|
| 181 |
fake_model = _FakeLoadModel()
|
tests/test_routes.py
CHANGED
|
@@ -27,7 +27,7 @@ def test_language_endpoint_contract(monkeypatch):
|
|
| 27 |
|
| 28 |
|
| 29 |
def test_labels_config_roundtrip():
|
| 30 |
-
response = client.post("/configlabel", json={"
|
| 31 |
assert response.status_code == 200
|
| 32 |
assert response.json() == ["tech", "health", "legal"]
|
| 33 |
|
|
@@ -36,6 +36,36 @@ def test_labels_config_roundtrip():
|
|
| 36 |
assert get_response.json() == ["tech", "health", "legal"]
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def test_transform_file_contract(monkeypatch):
|
| 40 |
monkeypatch.setattr(classification_pipeline, "transform_file", lambda filename, path: "extracted content")
|
| 41 |
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
def test_labels_config_roundtrip():
|
| 30 |
+
response = client.post("/configlabel", json={"labels": ["tech", "health", "legal"]})
|
| 31 |
assert response.status_code == 200
|
| 32 |
assert response.json() == ["tech", "health", "legal"]
|
| 33 |
|
|
|
|
| 36 |
assert get_response.json() == ["tech", "health", "legal"]
|
| 37 |
|
| 38 |
|
| 39 |
+
def test_labels_config_accepts_labels_list_payload():
|
| 40 |
+
response = client.post("/configlabel", json={"labels": ["tech", "health", "legal"]})
|
| 41 |
+
assert response.status_code == 200
|
| 42 |
+
assert response.json() == ["tech", "health", "legal"]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_labels_config_rejects_empty_labels():
|
| 46 |
+
response = client.post("/configlabel", json={"labels": [" ", ""]})
|
| 47 |
+
assert response.status_code == 400
|
| 48 |
+
assert response.json() == {"detail": "At least one label is required"}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_labels_config_rejects_missing_labels():
|
| 52 |
+
response = client.post("/configlabel", json={})
|
| 53 |
+
assert response.status_code == 422
|
| 54 |
+
assert "labels" in response.text
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_labels_config_rejects_text_field():
|
| 58 |
+
response = client.post("/configlabel", json={"text": "tech,health"})
|
| 59 |
+
assert response.status_code == 422
|
| 60 |
+
assert "extra_forbidden" in response.text
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def test_labels_config_rejects_texts_field():
|
| 64 |
+
response = client.post("/configlabel", json={"texts": ["tech,health"]})
|
| 65 |
+
assert response.status_code == 422
|
| 66 |
+
assert "extra_forbidden" in response.text
|
| 67 |
+
|
| 68 |
+
|
| 69 |
def test_transform_file_contract(monkeypatch):
|
| 70 |
monkeypatch.setattr(classification_pipeline, "transform_file", lambda filename, path: "extracted content")
|
| 71 |
|