kadarakos committed on
Commit
19d3bef
·
1 Parent(s): 772f25e

JIT for both labeler and detector + more drift metrics

Browse files
Files changed (1) hide show
  1. src/mentioned/app.py +90 -25
src/mentioned/app.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import os
2
  import gc
 
3
  import nltk
4
  from contextlib import asynccontextmanager
5
  from typing import List
@@ -8,14 +10,17 @@ from fastapi import FastAPI
8
  from pydantic import BaseModel
9
  from transformers import AutoTokenizer
10
  from prometheus_fastapi_instrumentator import Instrumentator
11
- from prometheus_client import Histogram
12
  from nltk.tokenize import word_tokenize
13
 
14
  # Internal package imports
15
  from mentioned.inference import (
16
  create_inference_model,
17
- compile_inference_model,
 
18
  ONNXMentionDetectorPipeline,
 
 
19
  )
20
 
21
 
@@ -32,18 +37,47 @@ class TextRequest(BaseModel):
32
  texts: List[str]
33
 
34
 
35
- MODEL_CONFIDENCE = Histogram(
36
  "mention_detector_confidence",
37
- "Distribution of prediction confidence scores",
38
  buckets=[0.1, 0.3, 0.5, 0.7, 0.8, 0.9, 1.0],
39
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  MENTIONS_PER_DOC = Histogram(
41
- "mention_detector_density",
42
  "Number of mentions detected per document",
43
  buckets=[0, 1, 2, 5, 10, 20, 50],
44
  )
45
- REPO_ID = os.getenv("REPO_ID", "kadarakos/mention-detector-poc-dry-run")
46
- ENGINE_DIR = "engine"
 
 
 
 
 
 
 
 
 
 
47
  MODEL_PATH = os.path.join(ENGINE_DIR, "model.onnx")
48
 
49
  state = {}
@@ -52,46 +86,77 @@ setup_nltk()
52
 
53
  @asynccontextmanager
54
  async def lifespan(app: FastAPI):
55
- """Handles the JIT compilation and RAM cleanup."""
 
56
  if not os.path.exists(MODEL_PATH):
57
- print(f"🏗️ Engine not found. Compiling from {REPO_ID}...")
58
- # create_inference_model respects HF_TOKEN env var automatically
59
- torch_model = create_inference_model(REPO_ID, "model_v1")
60
- compile_inference_model(torch_model, ENGINE_DIR)
 
 
 
 
 
 
 
 
61
  tokenizer = torch_model.tokenizer
62
  del torch_model
63
  gc.collect()
64
- print("✅ Compilation complete.")
 
 
 
 
 
 
 
65
  else:
66
- print("🚀 Loading existing ONNX engine...")
67
- tokenizer = AutoTokenizer.from_pretrained(ENGINE_DIR)
68
 
69
- state["pipeline"] = ONNXMentionDetectorPipeline(MODEL_PATH, tokenizer)
70
  yield
71
  state.clear()
72
 
73
-
74
  app = FastAPI(lifespan=lifespan)
75
  Instrumentator().instrument(app).expose(app)
76
 
77
 
78
  @app.post("/predict")
79
  async def predict(request: TextRequest):
80
- pipeline = state["pipeline"]
81
  docs = [word_tokenize(t) for t in request.texts]
82
- batch_results = pipeline.predict(docs)
83
- for doc_mentions in batch_results:
84
- MENTIONS_PER_DOC.observe(len(doc_mentions))
85
- for m in doc_mentions:
86
- MODEL_CONFIDENCE.observe(m["score"])
 
 
 
 
 
 
 
 
87
 
88
- return {"results": batch_results}
 
 
 
 
 
 
 
 
 
 
 
89
 
90
 
91
  @app.get("/")
92
  def home():
93
  return {
94
- "message": "Mention Detector API",
95
  "docs": "/docs",
96
  "metrics": "/metrics",
97
  }
 
1
+ import time
2
  import os
3
  import gc
4
+ import json
5
  import nltk
6
  from contextlib import asynccontextmanager
7
  from typing import List
 
10
  from pydantic import BaseModel
11
  from transformers import AutoTokenizer
12
  from prometheus_fastapi_instrumentator import Instrumentator
13
+ from prometheus_client import Histogram, Counter
14
  from nltk.tokenize import word_tokenize
15
 
16
  # Internal package imports
17
  from mentioned.inference import (
18
  create_inference_model,
19
+ compile_detector,
20
+ compile_labeler,
21
  ONNXMentionDetectorPipeline,
22
+ ONNXMentionLabelerPipeline,
23
+ InferenceMentionLabeler
24
  )
25
 
26
 
 
37
  texts: List[str]
38
 
39
 
40
# --- Observability metrics --------------------------------------------------
# Detector confidence: prediction scores are probabilities in [0, 1].
MENTION_CONFIDENCE = Histogram(
    "mention_detector_confidence",
    "Distribution of prediction confidence scores for detector.",
    buckets=[0.1, 0.3, 0.5, 0.7, 0.8, 0.9, 1.0],
)
# Labeler confidence is also a probability in [0, 1]. Without explicit
# buckets, prometheus_client falls back to latency-oriented defaults
# (.005 .. 10.0), which would collapse the whole score range into a couple
# of buckets — use the same score-oriented buckets as the detector.
ENTITY_CONFIDENCE = Histogram(
    "entity_labeler_confidence",
    "Distribution of prediction confidence scores for labeler.",
    buckets=[0.1, 0.3, 0.5, 0.7, 0.8, 0.9, 1.0],
)
ENTITY_LABEL_COUNTS = Counter(
    "entity_label_total",
    "Total count of predicted entity labels",
    ["label_name"],
)
INPUT_TOKENS = Histogram(
    "mention_input_tokens_count",
    "Number of tokens per input document",
    buckets=[1, 5, 10, 20, 50, 100, 250, 500],
)
MENTION_DENSITY = Histogram(
    "mention_density_ratio",
    "Ratio of mentions to total tokens in a document",
    buckets=[0.01, 0.05, 0.1, 0.2, 0.5],
)
MENTIONS_PER_DOC = Histogram(
    "mention_detector_count",
    "Number of mentions detected per document",
    buckets=[0, 1, 2, 5, 10, 20, 50],
)
INFERENCE_LATENCY = Histogram(
    "inference_duration_seconds",
    "Time spent in the model prediction pipeline",
    buckets=[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0],
)

# --- Model / engine configuration (overridable via environment) -------------
REPO_ID = os.getenv("REPO_ID", "kadarakos/entity-labeler-poc")
ENCODER_ID = os.getenv("ENCODER_ID", "distilroberta-base")
MODEL_FACTORY = os.getenv("MODEL_FACTORY", "model_v2")
DATA_FACTORY = os.getenv("DATA_FACTORY", "litbank_entities")
# Local directory holding the compiled ONNX artifact, tokenizer and config.
ENGINE_DIR = "model_v2_artifact"
MODEL_PATH = os.path.join(ENGINE_DIR, "model.onnx")

# Shared application state; populated by the lifespan handler at startup.
state = {}
 
86
 
87
@asynccontextmanager
async def lifespan(app: FastAPI):
    """JIT compilation and loading for both Detector and Labeler.

    Startup: if no compiled engine exists at MODEL_PATH, build the torch
    model from the configured repo, compile it to ONNX (labeler or detector
    depending on the model type) and persist a small ``config.json``
    describing the artifact. Then load the tokenizer and config from
    ENGINE_DIR and install the matching ONNX pipeline into the shared
    ``state`` dict. Shutdown: clear ``state``.
    """
    if not os.path.exists(MODEL_PATH):
        print(f"🏗️ Engine not found. Compiling {MODEL_FACTORY} from {REPO_ID}...")
        torch_model = create_inference_model(REPO_ID, ENCODER_ID, MODEL_FACTORY, DATA_FACTORY)

        # NOTE(review): assumes compile_labeler/compile_detector also save
        # the tokenizer into ENGINE_DIR, since it is reloaded from there
        # below — confirm against mentioned.inference.
        if isinstance(torch_model, InferenceMentionLabeler):
            compile_labeler(torch_model, ENGINE_DIR)
            artifact_config = {"id2label": torch_model.id2label, "type": "labeler"}
        else:
            compile_detector(torch_model, ENGINE_DIR)
            artifact_config = {"type": "detector"}
        with open(os.path.join(ENGINE_DIR, "config.json"), "w") as f:
            json.dump(artifact_config, f)

        # Free the torch model before loading the ONNX engine to keep peak
        # RAM low. (The original also bound torch_model.tokenizer here, but
        # that binding was dead: it was unconditionally overwritten by the
        # from_pretrained call below.)
        del torch_model
        gc.collect()

    tokenizer = AutoTokenizer.from_pretrained(ENGINE_DIR)
    with open(os.path.join(ENGINE_DIR, "config.json"), "r") as f:
        config = json.load(f)

    if config.get("type") == "labeler":
        # JSON object keys are strings; the pipeline expects int class ids.
        id2label = {int(k): v for k, v in config["id2label"].items()}
        state["pipeline"] = ONNXMentionLabelerPipeline(MODEL_PATH, tokenizer, id2label)
    else:
        state["pipeline"] = ONNXMentionDetectorPipeline(MODEL_PATH, tokenizer)

    yield
    state.clear()
120
 
 
121
# Wire the app to the startup/shutdown handler above and expose Prometheus
# metrics at /metrics via the instrumentator middleware.
app = FastAPI(lifespan=lifespan)
Instrumentator().instrument(app).expose(app)
123
 
124
 
125
@app.post("/predict")
async def predict(request: TextRequest):
    """Tokenize the request texts, run the pipeline and record drift metrics.

    Returns the raw pipeline results together with the model repo id.
    """
    tokenized_docs = [word_tokenize(text) for text in request.texts]

    t0 = time.perf_counter()
    predictions = state["pipeline"].predict(tokenized_docs)
    INFERENCE_LATENCY.observe(time.perf_counter() - t0)

    for tokens, mentions in zip(tokenized_docs, predictions):
        _record_doc_metrics(tokens, mentions)

    return {"results": predictions, "model_repo": REPO_ID}


def _record_doc_metrics(tokens, mentions):
    """Update per-document observability metrics for a single prediction."""
    n_tokens = len(tokens)
    n_mentions = len(mentions)

    # Input size and mention-density drift metrics.
    INPUT_TOKENS.observe(n_tokens)
    MENTIONS_PER_DOC.observe(n_mentions)
    if n_tokens > 0:
        MENTION_DENSITY.observe(n_mentions / n_tokens)

    for mention in mentions:
        # Detector confidence (0 when the pipeline output omits a score).
        MENTION_CONFIDENCE.observe(mention.get("score", 0))

        # Labeler-specific fields are only present for labeler pipelines.
        if "label" in mention:
            ENTITY_LABEL_COUNTS.labels(label_name=mention["label"]).inc()
            # Only observe label_score when the output actually carries it.
            if "label_score" in mention:
                ENTITY_CONFIDENCE.observe(mention["label_score"])
154
 
155
 
156
@app.get("/")
def home():
    """Landing endpoint advertising the interactive docs and metrics routes."""
    return dict(
        message="Mention Detector and Labeler API.",
        docs="/docs",
        metrics="/metrics",
    )