AyoubChLin committed on
Commit
2571402
·
1 Parent(s): 2d0ef3b

[REF] api documentation

Browse files
.env.example CHANGED
@@ -8,5 +8,6 @@ UPLOAD_SUBDIR=uploads
8
  CLASSIFIER_MODEL=AyoubChLin/bert-base-uncased-zeroshot-nli
9
  ENABLE_MODEL_QUANTIZATION=true
10
  HUGGINGFACE_TOKEN=
 
11
 
12
  DEFAULT_LABELS_CSV=news,sport,finance,politics
 
8
  CLASSIFIER_MODEL=AyoubChLin/bert-base-uncased-zeroshot-nli
9
  ENABLE_MODEL_QUANTIZATION=true
10
  HUGGINGFACE_TOKEN=
11
+ CLASSIFIER_ENTAILMENT_LABEL_ID=
12
 
13
  DEFAULT_LABELS_CSV=news,sport,finance,politics
README.md CHANGED
@@ -27,6 +27,12 @@ Refactored into a modular FastAPI backend with clear layers:
27
  - `POST /configlabel` -> returns labels array
28
  - `GET /labels` -> returns labels array
29
 
 
 
 
 
 
 
30
  Additional operational endpoints:
31
  - `GET /health/liveness`
32
  - `GET /health/readiness`
@@ -42,6 +48,7 @@ Key vars:
42
  - `CLASSIFIER_MODEL`
43
  - `ENABLE_MODEL_QUANTIZATION`
44
  - `HUGGINGFACE_TOKEN`
 
45
  - `DEFAULT_LABELS_CSV`
46
 
47
  ## Local Run
 
27
  - `POST /configlabel` -> returns labels array
28
  - `GET /labels` -> returns labels array
29
 
30
+ `POST /configlabel` exact payload:
31
+ - body accepts `{"labels":["label1","label2","label3"]}`
32
+ - each label is trimmed and empty values are removed
33
+ - duplicates are kept if they are provided
34
+ - returns the stored `string[]` labels
35
+
36
  Additional operational endpoints:
37
  - `GET /health/liveness`
38
  - `GET /health/readiness`
 
48
  - `CLASSIFIER_MODEL`
49
  - `ENABLE_MODEL_QUANTIZATION`
50
  - `HUGGINGFACE_TOKEN`
51
+ - `CLASSIFIER_ENTAILMENT_LABEL_ID` (optional integer override for the entailment logit index, used when the model config does not expose an entailment label)
52
  - `DEFAULT_LABELS_CSV`
53
 
54
  ## Local Run
app/core/config.py CHANGED
@@ -18,6 +18,7 @@ class Settings(BaseSettings):
18
  classifier_model: str = "AyoubChLin/bert-base-uncased-zeroshot-nli"
19
  enable_model_quantization: bool = True
20
  huggingface_token: str | None = None
 
21
 
22
  default_labels_csv: str = Field(default="news,sport,finance,politics")
23
 
 
18
  classifier_model: str = "AyoubChLin/bert-base-uncased-zeroshot-nli"
19
  enable_model_quantization: bool = True
20
  huggingface_token: str | None = None
21
+ classifier_entailment_label_id: int | None = None
22
 
23
  default_labels_csv: str = Field(default="news,sport,finance,politics")
24
 
app/routers/classification.py CHANGED
@@ -56,7 +56,7 @@ async def classify_uploaded_file(file: UploadFile = File(...)) -> dict:
56
 
57
  @router.post("/configlabel", response_model=list[str])
58
  async def configure_labels(payload: LabelUpdateInput) -> list[str]:
59
- labels = label_service.set_labels_from_csv(payload.text)
60
  if not labels:
61
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one label is required")
62
  return labels
 
56
 
57
  @router.post("/configlabel", response_model=list[str])
58
  async def configure_labels(payload: LabelUpdateInput) -> list[str]:
59
+ labels = label_service.set_labels(payload.get_normalized_labels())
60
  if not labels:
61
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one label is required")
62
  return labels
app/schemas/classification.py CHANGED
@@ -10,7 +10,14 @@ class TextInput(BaseSchema):
10
 
11
 
12
  class LabelUpdateInput(BaseSchema):
13
- text: str = Field(min_length=1, description="Comma-separated labels, e.g. 'news, sport, finance'")
 
 
 
 
 
 
 
14
 
15
 
16
  class ClassifierResponse(BaseSchema):
 
10
 
11
 
12
  class LabelUpdateInput(BaseSchema):
13
+ labels: list[str] = Field(
14
+ min_length=1,
15
+ description="Direct list of labels. Items are trimmed and empty values are removed.",
16
+ examples=[["news", "sport", "finance"]],
17
+ )
18
+
19
+ def get_normalized_labels(self) -> list[str]:
20
+ return [label.strip() for label in self.labels if isinstance(label, str) and label.strip()]
21
 
22
 
23
  class ClassifierResponse(BaseSchema):
app/services/classifier_service.py CHANGED
@@ -61,19 +61,77 @@ class ClassifierService:
61
  cleaned = [label.strip() for label in labels if isinstance(label, str) and label.strip()]
62
  return list(dict.fromkeys(cleaned))
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  @staticmethod
65
  def _resolve_entailment_id(model: Any) -> int:
66
  label2id = getattr(model.config, "label2id", {}) or {}
67
  for label, label_id in label2id.items():
68
  if isinstance(label, str) and label.lower().startswith("entail"):
69
- return int(label_id)
 
 
70
 
71
  id2label = getattr(model.config, "id2label", {}) or {}
72
  for label_id, label in id2label.items():
73
  if isinstance(label, str) and label.lower().startswith("entail"):
74
- return int(label_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- raise ClassificationError("Classifier model is missing an entailment label mapping")
 
 
 
 
77
 
78
  def classify(self, text: str, labels: list[str]) -> str:
79
  candidate_labels = self._normalize_labels(labels)
 
61
  cleaned = [label.strip() for label in labels if isinstance(label, str) and label.strip()]
62
  return list(dict.fromkeys(cleaned))
63
 
64
+ @staticmethod
65
+ def _parse_label_id(value: Any) -> int | None:
66
+ try:
67
+ return int(value)
68
+ except (TypeError, ValueError):
69
+ return None
70
+
71
+ @staticmethod
72
+ def _extract_task_specific_entailment_id(model: Any) -> int | None:
73
+ task_specific_params = getattr(model.config, "task_specific_params", {}) or {}
74
+ if not isinstance(task_specific_params, dict):
75
+ return None
76
+
77
+ zero_shot_params = task_specific_params.get("zero-shot-classification", {})
78
+ if not isinstance(zero_shot_params, dict):
79
+ return None
80
+
81
+ return ClassifierService._parse_label_id(zero_shot_params.get("entailment_id"))
82
+
83
+ @staticmethod
84
+ def _has_generic_label_names(model: Any) -> bool:
85
+ label2id = getattr(model.config, "label2id", {}) or {}
86
+ id2label = getattr(model.config, "id2label", {}) or {}
87
+
88
+ labels: list[str] = []
89
+ labels.extend(label for label in label2id.keys() if isinstance(label, str))
90
+ labels.extend(label for label in id2label.values() if isinstance(label, str))
91
+ if not labels:
92
+ return False
93
+
94
+ return all(label.lower().startswith("label_") for label in labels)
95
+
96
  @staticmethod
97
  def _resolve_entailment_id(model: Any) -> int:
98
  label2id = getattr(model.config, "label2id", {}) or {}
99
  for label, label_id in label2id.items():
100
  if isinstance(label, str) and label.lower().startswith("entail"):
101
+ parsed = ClassifierService._parse_label_id(label_id)
102
+ if parsed is not None:
103
+ return parsed
104
 
105
  id2label = getattr(model.config, "id2label", {}) or {}
106
  for label_id, label in id2label.items():
107
  if isinstance(label, str) and label.lower().startswith("entail"):
108
+ parsed = ClassifierService._parse_label_id(label_id)
109
+ if parsed is not None:
110
+ return parsed
111
+
112
+ task_specific_entailment_id = ClassifierService._extract_task_specific_entailment_id(model)
113
+ if task_specific_entailment_id is not None:
114
+ return task_specific_entailment_id
115
+
116
+ if settings.classifier_entailment_label_id is not None:
117
+ return settings.classifier_entailment_label_id
118
+
119
+ num_labels = ClassifierService._parse_label_id(getattr(model.config, "num_labels", None))
120
+ if num_labels == 3 and (
121
+ ClassifierService._has_generic_label_names(model) or (not label2id and not id2label)
122
+ ):
123
+ logger.warning(
124
+ "Falling back to entailment label id 2 because model config labels are generic or missing "
125
+ "and no explicit entailment mapping was found. Set CLASSIFIER_ENTAILMENT_LABEL_ID "
126
+ "to override this behavior."
127
+ )
128
+ return 2
129
 
130
+ raise ClassificationError(
131
+ "Classifier model is missing an entailment label mapping. "
132
+ "Set CLASSIFIER_ENTAILMENT_LABEL_ID in the environment when the model config "
133
+ "does not expose an entailment label."
134
+ )
135
 
136
  def classify(self, text: str, labels: list[str]) -> str:
137
  candidate_labels = self._normalize_labels(labels)
app/services/label_service.py CHANGED
@@ -14,5 +14,8 @@ class LabelService:
14
  labels = [label.strip() for label in labels_csv.split(",") if label.strip()]
15
  return self._config.set_labels(labels)
16
 
 
 
 
17
 
18
  label_service = LabelService()
 
14
  labels = [label.strip() for label in labels_csv.split(",") if label.strip()]
15
  return self._config.set_labels(labels)
16
 
17
+ def set_labels(self, labels: list[str]) -> list[str]:
18
+ return self._config.set_labels(labels)
19
+
20
 
21
  label_service = LabelService()
docs/reference/api.md CHANGED
@@ -18,9 +18,33 @@ Evidence:
18
  | POST | `/api/language` | `{text}` | `"<language>"` |
19
  | POST | `/api/transformer` | multipart `file` | `{filename, content}` |
20
  | POST | `/classify` | multipart `file` | `{label, language, type?}` |
21
- | POST | `/configlabel` | `{text: "csv,labels"}` | `string[]` |
22
  | GET | `/labels` | none | `string[]` |
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  ## Validation and errors
25
  - `400` for input validation and extraction problems.
26
  - `502` for classifier/language inference failures.
 
18
  | POST | `/api/language` | `{text}` | `"<language>"` |
19
  | POST | `/api/transformer` | multipart `file` | `{filename, content}` |
20
  | POST | `/classify` | multipart `file` | `{label, language, type?}` |
21
+ | POST | `/configlabel` | `{labels:["a","b"]}` | `string[]` |
22
  | GET | `/labels` | none | `string[]` |
23
 
24
+ ## `POST /configlabel` contract (exact)
25
+
26
+ Request body:
27
+ - JSON object with required `labels` array.
28
+ - Example: `{"labels":["tech","health","legal"]}`
29
+
30
+ Normalization behavior:
31
+ - Trim whitespace for each label.
32
+ - Remove empty entries.
33
+ - Preserve order.
34
+ - Keep duplicates as provided.
35
+
36
+ Response:
37
+ - `200` with `string[]`, the stored labels after normalization.
38
+ - Example response: `["tech", "health", "legal"]`
39
+
40
+ Errors:
41
+ - `400` with detail `"At least one label is required"` when every provided label is empty or whitespace-only after trimming.
42
+ - `422` when `labels` is missing, is an empty array, or unknown fields are provided (schema validation error).
43
+
44
+ Related state behavior:
45
+ - Labels are process-local in memory and reset on restart.
46
+ - `GET /labels` returns the same current in-memory list.
47
+
48
  ## Validation and errors
49
  - `400` for input validation and extraction problems.
50
  - `502` for classifier/language inference failures.
docs/reference/configuration.md CHANGED
@@ -29,6 +29,7 @@ Evidence:
29
  | `CLASSIFIER_MODEL` | `AyoubChLin/bert-base-uncased-zeroshot-nli` | Hugging Face model ID used for local zero-shot NLI classification |
30
  | `ENABLE_MODEL_QUANTIZATION` | `true` | enable dynamic INT8 quantization with automatic fallback |
31
  | `HUGGINGFACE_TOKEN` | empty | optional auth token for client init |
 
32
 
33
  ## Language detector settings
34
 
@@ -45,6 +46,7 @@ Evidence:
45
  ## Behavior notes
46
  - Labels are process-local in memory and reset on restart.
47
  - Upload directory is auto-created at app startup.
 
48
 
49
  Evidence:
50
  - `app/services/label_service.py`
 
29
  | `CLASSIFIER_MODEL` | `AyoubChLin/bert-base-uncased-zeroshot-nli` | Hugging Face model ID used for local zero-shot NLI classification |
30
  | `ENABLE_MODEL_QUANTIZATION` | `true` | enable dynamic INT8 quantization with automatic fallback |
31
  | `HUGGINGFACE_TOKEN` | empty | optional auth token for client init |
32
+ | `CLASSIFIER_ENTAILMENT_LABEL_ID` | empty | optional integer override for entailment logit index when model config does not expose an `entailment` label |
33
 
34
  ## Language detector settings
35
 
 
46
  ## Behavior notes
47
  - Labels are process-local in memory and reset on restart.
48
  - Upload directory is auto-created at app startup.
49
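+ - If neither `label2id` nor `id2label` includes an entailment label, the service checks the model's task-specific config (`task_specific_params["zero-shot-classification"]["entailment_id"]`), then `CLASSIFIER_ENTAILMENT_LABEL_ID`, then falls back to index `2` when the model has 3 labels and its label mappings are generic (`LABEL_*`) or missing. If none of these apply, classification fails with an error asking for `CLASSIFIER_ENTAILMENT_LABEL_ID` to be set.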
50
 
51
  Evidence:
52
  - `app/services/label_service.py`
docs/tutorials/getting-started.md CHANGED
@@ -60,7 +60,7 @@ Evidence:
60
  ```bash
61
  curl -s -X POST http://localhost:4002/configlabel \
62
  -H 'content-type: application/json' \
63
- -d '{"text":"tech,health,legal"}'
64
 
65
  curl -s http://localhost:4002/labels
66
  ```
 
60
  ```bash
61
  curl -s -X POST http://localhost:4002/configlabel \
62
  -H 'content-type: application/json' \
63
+ -d '{"labels":["tech","health","legal"]}'
64
 
65
  curl -s http://localhost:4002/labels
66
  ```
tests/test_classifier_service.py CHANGED
@@ -15,9 +15,9 @@ class _FakeTokenizer:
15
 
16
 
17
  class _FakeInferenceModel:
18
- def __init__(self, logits: torch.Tensor) -> None:
19
  self._logits = logits
20
- self.config = SimpleNamespace(
21
  label2id={"CONTRADICTION": 0, "ENTAILMENT": 1},
22
  id2label={0: "CONTRADICTION", 1: "ENTAILMENT"},
23
  )
@@ -63,6 +63,119 @@ def test_classify_uses_runtime_candidate_labels(monkeypatch):
63
  assert predicted == "sport"
64
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def test_model_quantization_falls_back_to_non_quantized_model(monkeypatch):
67
  service = classifier_module.ClassifierService()
68
  fake_model = _FakeLoadModel()
 
15
 
16
 
17
  class _FakeInferenceModel:
18
+ def __init__(self, logits: torch.Tensor, config: SimpleNamespace | None = None) -> None:
19
  self._logits = logits
20
+ self.config = config or SimpleNamespace(
21
  label2id={"CONTRADICTION": 0, "ENTAILMENT": 1},
22
  id2label={0: "CONTRADICTION", 1: "ENTAILMENT"},
23
  )
 
63
  assert predicted == "sport"
64
 
65
 
66
+ def test_classify_uses_task_specific_entailment_id_when_label_names_are_generic(monkeypatch):
67
+ service = classifier_module.ClassifierService()
68
+ tokenizer = _FakeTokenizer()
69
+ model = _FakeInferenceModel(
70
+ logits=torch.tensor(
71
+ [
72
+ [1.8, 0.3, 0.4], # finance -> low entailment
73
+ [0.4, 0.7, 3.7], # sport -> highest entailment
74
+ ]
75
+ ),
76
+ config=SimpleNamespace(
77
+ label2id={"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2},
78
+ id2label={0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"},
79
+ task_specific_params={"zero-shot-classification": {"entailment_id": 2}},
80
+ num_labels=3,
81
+ ),
82
+ )
83
+
84
+ monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
85
+ monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None)
86
+
87
+ predicted = service.classify(
88
+ "The story is mostly about football transfers.",
89
+ ["finance", "sport"],
90
+ )
91
+
92
+ assert predicted == "sport"
93
+
94
+
95
+ def test_classify_uses_explicit_entailment_id_setting_when_mapping_is_missing(monkeypatch):
96
+ service = classifier_module.ClassifierService()
97
+ tokenizer = _FakeTokenizer()
98
+ model = _FakeInferenceModel(
99
+ logits=torch.tensor(
100
+ [
101
+ [2.0, 0.3], # finance -> low entailment
102
+ [0.2, 3.4], # sport -> highest entailment
103
+ ]
104
+ ),
105
+ config=SimpleNamespace(
106
+ label2id={"NEGATIVE": 0, "POSITIVE": 1},
107
+ id2label={0: "NEGATIVE", 1: "POSITIVE"},
108
+ num_labels=2,
109
+ ),
110
+ )
111
+
112
+ monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
113
+ monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", 1)
114
+
115
+ predicted = service.classify(
116
+ "The story is mostly about football transfers.",
117
+ ["finance", "sport"],
118
+ )
119
+
120
+ assert predicted == "sport"
121
+
122
+
123
+ def test_classify_falls_back_to_mnli_entailment_index_for_generic_three_label_configs(monkeypatch):
124
+ service = classifier_module.ClassifierService()
125
+ tokenizer = _FakeTokenizer()
126
+ model = _FakeInferenceModel(
127
+ logits=torch.tensor(
128
+ [
129
+ [2.3, 0.6, 0.8], # finance -> low entailment
130
+ [0.4, 0.8, 3.9], # sport -> highest entailment
131
+ ]
132
+ ),
133
+ config=SimpleNamespace(
134
+ label2id={"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2},
135
+ id2label={0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"},
136
+ num_labels=3,
137
+ ),
138
+ )
139
+
140
+ monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
141
+ monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None)
142
+
143
+ predicted = service.classify(
144
+ "The story is mostly about football transfers.",
145
+ ["finance", "sport"],
146
+ )
147
+
148
+ assert predicted == "sport"
149
+
150
+
151
+ def test_classify_falls_back_to_mnli_entailment_index_for_missing_label_mapping(monkeypatch):
152
+ service = classifier_module.ClassifierService()
153
+ tokenizer = _FakeTokenizer()
154
+ model = _FakeInferenceModel(
155
+ logits=torch.tensor(
156
+ [
157
+ [1.8, 0.4, 0.5], # finance -> low entailment
158
+ [0.2, 0.5, 3.6], # sport -> highest entailment
159
+ ]
160
+ ),
161
+ config=SimpleNamespace(
162
+ label2id={},
163
+ id2label={},
164
+ num_labels=3,
165
+ ),
166
+ )
167
+
168
+ monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
169
+ monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None)
170
+
171
+ predicted = service.classify(
172
+ "The story is mostly about football transfers.",
173
+ ["finance", "sport"],
174
+ )
175
+
176
+ assert predicted == "sport"
177
+
178
+
179
  def test_model_quantization_falls_back_to_non_quantized_model(monkeypatch):
180
  service = classifier_module.ClassifierService()
181
  fake_model = _FakeLoadModel()
tests/test_routes.py CHANGED
@@ -27,7 +27,7 @@ def test_language_endpoint_contract(monkeypatch):
27
 
28
 
29
  def test_labels_config_roundtrip():
30
- response = client.post("/configlabel", json={"text": "tech, health, legal"})
31
  assert response.status_code == 200
32
  assert response.json() == ["tech", "health", "legal"]
33
 
@@ -36,6 +36,36 @@ def test_labels_config_roundtrip():
36
  assert get_response.json() == ["tech", "health", "legal"]
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def test_transform_file_contract(monkeypatch):
40
  monkeypatch.setattr(classification_pipeline, "transform_file", lambda filename, path: "extracted content")
41
 
 
27
 
28
 
29
  def test_labels_config_roundtrip():
30
+ response = client.post("/configlabel", json={"labels": ["tech", "health", "legal"]})
31
  assert response.status_code == 200
32
  assert response.json() == ["tech", "health", "legal"]
33
 
 
36
  assert get_response.json() == ["tech", "health", "legal"]
37
 
38
 
39
+ def test_labels_config_accepts_labels_list_payload():
40
+ response = client.post("/configlabel", json={"labels": ["tech", "health", "legal"]})
41
+ assert response.status_code == 200
42
+ assert response.json() == ["tech", "health", "legal"]
43
+
44
+
45
+ def test_labels_config_rejects_empty_labels():
46
+ response = client.post("/configlabel", json={"labels": [" ", ""]})
47
+ assert response.status_code == 400
48
+ assert response.json() == {"detail": "At least one label is required"}
49
+
50
+
51
+ def test_labels_config_rejects_missing_labels():
52
+ response = client.post("/configlabel", json={})
53
+ assert response.status_code == 422
54
+ assert "labels" in response.text
55
+
56
+
57
+ def test_labels_config_rejects_text_field():
58
+ response = client.post("/configlabel", json={"text": "tech,health"})
59
+ assert response.status_code == 422
60
+ assert "extra_forbidden" in response.text
61
+
62
+
63
+ def test_labels_config_rejects_texts_field():
64
+ response = client.post("/configlabel", json={"texts": ["tech,health"]})
65
+ assert response.status_code == 422
66
+ assert "extra_forbidden" in response.text
67
+
68
+
69
  def test_transform_file_contract(monkeypatch):
70
  monkeypatch.setattr(classification_pipeline, "transform_file", lambda filename, path: "extracted content")
71