AyoubChLin committed on
Commit
59b46a2
·
1 Parent(s): efddb2f

feat: enhance classifier service with model warmup and dynamic quantization

Browse files
Files changed (3) hide show
  1. .gitignore +3 -1
  2. app/main.py +15 -0
  3. app/services/classifier_service.py +42 -13
.gitignore CHANGED
@@ -4,4 +4,6 @@ __pycache__/
4
  *.pyc
5
  .pytest_cache/
6
  static/uploads/
7
- static/*
 
 
 
4
  *.pyc
5
  .pytest_cache/
6
  static/uploads/
7
+ static/*
8
+ venv
9
+ .qwen
app/main.py CHANGED
@@ -1,8 +1,14 @@
 
 
1
  from fastapi import FastAPI
2
  from fastapi.staticfiles import StaticFiles
3
 
4
  from app.api.router import api_router
5
  from app.core.config import settings
 
 
 
 
6
 
7
  settings.static_dir.mkdir(parents=True, exist_ok=True)
8
  settings.upload_dir.mkdir(parents=True, exist_ok=True)
@@ -12,6 +18,15 @@ app.mount("/static", StaticFiles(directory=str(settings.static_dir)), name="stat
12
  app.include_router(api_router)
13
 
14
 
 
 
 
 
 
 
 
 
 
15
  @app.get("/endpoint/")
16
  def list_endpoints() -> list[dict]:
17
  endpoints = []
 
1
+ import logging
2
+
3
  from fastapi import FastAPI
4
  from fastapi.staticfiles import StaticFiles
5
 
6
  from app.api.router import api_router
7
  from app.core.config import settings
8
+ from app.core.exceptions import ClassificationError
9
+ from app.services.classifier_service import classifier_service
10
+
11
+ logger = logging.getLogger(__name__)
12
 
13
  settings.static_dir.mkdir(parents=True, exist_ok=True)
14
  settings.upload_dir.mkdir(parents=True, exist_ok=True)
 
18
  app.include_router(api_router)
19
 
20
 
21
@app.on_event("startup")
def preload_classifier_model() -> None:
    """Warm up the classifier service when the application starts.

    A ``ClassificationError`` raised by the warmup is logged together with
    its traceback and then swallowed, so a failed preload does not abort
    application startup.
    """
    try:
        classifier_service.warmup()
    except ClassificationError:
        logger.exception("Classifier model warmup failed")
    else:
        # Only reached when warmup() returned without raising.
        logger.info("Classifier model preloaded on startup")
28
+
29
+
30
  @app.get("/endpoint/")
31
  def list_endpoints() -> list[dict]:
32
  endpoints = []
app/services/classifier_service.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import Any
2
 
3
- from transformers import pipeline
 
4
 
5
  from app.core.config import settings
6
  from app.core.exceptions import ClassificationError
@@ -8,31 +9,59 @@ from app.core.exceptions import ClassificationError
8
 
9
  class ClassifierService:
10
  def __init__(self) -> None:
11
- self._pipeline: Any | None = None
 
12
 
13
- def _get_pipeline(self) -> Any:
14
- if self._pipeline is None:
15
  try:
16
- self._pipeline = pipeline(
17
- "zero-shot-classification",
18
- model=settings.classifier_model,
 
 
 
 
 
 
 
 
19
  )
 
 
 
20
  except Exception as exc:
21
- raise ClassificationError("Unable to initialize classifier pipeline") from exc
22
- return self._pipeline
 
 
 
 
23
 
24
  def classify(self, text: str, labels: list[str]) -> str:
25
  if not labels:
26
  raise ClassificationError("No labels configured")
27
 
 
 
28
  try:
29
- result = self._get_pipeline()(text, labels, multi_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  except Exception as exc:
31
  raise ClassificationError("Classifier prediction failed") from exc
32
 
33
- if isinstance(result, dict) and "labels" in result and result["labels"]:
34
- return result["labels"][0]
35
-
36
  raise ClassificationError("Classifier did not return a valid label")
37
 
38
 
 
1
  from typing import Any
2
 
3
+ import torch
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
 
6
  from app.core.config import settings
7
  from app.core.exceptions import ClassificationError
 
9
 
10
  class ClassifierService:
11
  def __init__(self) -> None:
12
+ self._tokenizer: Any | None = None
13
+ self._model: Any | None = None
14
 
15
+ def _load_model(self) -> tuple[Any, Any]:
16
+ if self._tokenizer is None or self._model is None:
17
  try:
18
+ tokenizer = AutoTokenizer.from_pretrained(settings.classifier_model)
19
+
20
+ model = AutoModelForSequenceClassification.from_pretrained(settings.classifier_model)
21
+ model.eval()
22
+ model.to("cpu")
23
+
24
+ # Dynamic INT8 quantization for CPU inference.
25
+ quantized_model = torch.quantization.quantize_dynamic(
26
+ model,
27
+ {torch.nn.Linear},
28
+ dtype=torch.qint8,
29
  )
30
+
31
+ self._tokenizer = tokenizer
32
+ self._model = quantized_model
33
  except Exception as exc:
34
+ raise ClassificationError("Unable to initialize classifier model") from exc
35
+
36
+ return self._tokenizer, self._model
37
+
38
+ def warmup(self) -> None:
39
+ self._load_model()
40
 
41
  def classify(self, text: str, labels: list[str]) -> str:
42
  if not labels:
43
  raise ClassificationError("No labels configured")
44
 
45
+ tokenizer, model = self._load_model()
46
+
47
  try:
48
+ inputs = tokenizer(
49
+ text,
50
+ padding=True,
51
+ truncation=True,
52
+ return_tensors="pt",
53
+ )
54
+
55
+ with torch.no_grad():
56
+ logits = model(**inputs).logits
57
+
58
+ predicted_class_id = logits.argmax(dim=-1).item()
59
+ predicted_label = model.config.id2label.get(predicted_class_id)
60
+ if predicted_label:
61
+ return str(predicted_label)
62
  except Exception as exc:
63
  raise ClassificationError("Classifier prediction failed") from exc
64
 
 
 
 
65
  raise ClassificationError("Classifier did not return a valid label")
66
 
67