kadarakos committed on
Commit
fd47265
Β·
1 Parent(s): f4ae92f

prometheus

Browse files
Files changed (2) hide show
  1. Dockerfile +26 -0
  2. src/mentioned/app.py +54 -40
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Build on the official uv image: Python 3.12 on Debian bookworm-slim with uv preinstalled.
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim

# Stay in root to keep paths simple
WORKDIR /

# 1. Install dependencies (Cached layer)
# We need --extra train because we need Torch for the initial compilation
# --no-install-project installs only third-party deps so this layer is reused
# until pyproject.toml/uv.lock change.
COPY pyproject.toml uv.lock ./
RUN uv sync --frozen --no-install-project --extra train

# 2. Pre-bake NLTK data so it doesn't download on every request
# NOTE(review): this downloads as root (typically into /root/nltk_data);
# if the platform runs the container as a non-root user (as HF Spaces does),
# confirm the data directory is still readable at runtime.
RUN uv run python -m nltk.downloader punkt punkt_tab

# 3. Copy only the source code (Excludes ONNX via .dockerignore)
COPY src ./src
COPY README.md ./

# 4. Final project install
RUN uv sync --frozen --extra train

# 5. HF Space defaults
ENV PORT=7860
EXPOSE 7860

# Run the app. The 'lifespan' in mentioned.app will handle the download/ONNX export.
CMD ["uv", "run", "python", "-m", "uvicorn", "mentioned.app:app", "--host", "0.0.0.0", "--port", "7860"]
src/mentioned/app.py CHANGED
@@ -3,81 +3,95 @@ import gc
3
  import nltk
4
  from contextlib import asynccontextmanager
5
  from typing import List
6
- from nltk.tokenize import word_tokenize
7
 
8
  from fastapi import FastAPI
9
  from pydantic import BaseModel
10
  from transformers import AutoTokenizer
 
 
 
11
 
12
-
13
- # Internal imports from your package
14
  from mentioned.inference import (
15
  create_inference_model,
16
  compile_inference_model,
17
  ONNXMentionDetectorPipeline,
18
  )
19
 
20
- REPO_ID = "kadarakos/mention-detector-poc-dry-run"
21
- ONNX_DIR = "model_v1_onnx"
22
- MODEL_PATH = os.path.join(ONNX_DIR, "model.onnx")
23
-
24
- # We use a global dict to store the pipeline after the heavy startup
25
- state = {}
26
-
27
 
28
- def ensure_nltk_resources():
29
  resources = ["punkt", "punkt_tab"]
30
  for res in resources:
31
  try:
32
  nltk.data.find(f"tokenizers/{res}")
33
  except LookupError:
34
- print(f"gettin' {res} for ya...")
35
  nltk.download(res)
36
 
37
 
38
- ensure_nltk_resources()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  @asynccontextmanager
42
  async def lifespan(app: FastAPI):
 
43
  if not os.path.exists(MODEL_PATH):
44
- print(f"πŸ—οΈ Compiling model from {REPO_ID}...")
 
45
  torch_model = create_inference_model(REPO_ID, "model_v1")
46
- compile_inference_model(torch_model, MODEL_PATH)
47
- # Keep tokenizer, evict Torch
48
  tokenizer = torch_model.tokenizer
49
  del torch_model
50
  gc.collect()
51
- print("βœ… Compilation complete. RAM cleared.")
52
  else:
53
- print("πŸš€ Loading existing ONNX model...")
54
- tokenizer = AutoTokenizer.from_pretrained(ONNX_DIR)
55
-
56
- state["pipeline"] = ONNXMentionDetectorPipeline(
57
- MODEL_PATH,
58
- tokenizer,
59
- # TODO Don't hardcode!
60
- threshold=0.3,
61
- )
62
  yield
63
  state.clear()
64
 
65
  app = FastAPI(lifespan=lifespan)
66
-
67
-
68
- class TextRequest(BaseModel):
69
- texts: List[str]
70
 
71
 
72
  @app.post("/predict")
73
  async def predict(request: TextRequest):
74
- docs = [word_tokenize(text) for text in request.texts]
75
- # docs = [text.split() for text in request.texts]
76
- results = state["pipeline"].predict(docs)
77
- print("YEAH")
78
- return {"results": results}
79
-
80
-
81
- @app.get("/health")
82
- def health():
83
- return {"status": "ok", "onnx_exists": os.path.exists(MODEL_PATH)}
 
 
 
 
 
 
 
 
 
3
  import nltk
4
  from contextlib import asynccontextmanager
5
  from typing import List
 
6
 
7
  from fastapi import FastAPI
8
  from pydantic import BaseModel
9
  from transformers import AutoTokenizer
10
+ from prometheus_fastapi_instrumentator import Instrumentator
11
+ from prometheus_client import Histogram
12
+ from nltk.tokenize import word_tokenize
13
 
14
+ # Internal package imports
 
15
  from mentioned.inference import (
16
  create_inference_model,
17
  compile_inference_model,
18
  ONNXMentionDetectorPipeline,
19
  )
20
 
 
 
 
 
 
 
 
21
 
22
def setup_nltk():
    """Ensure the NLTK punkt tokenizer models are present on this machine.

    Checks the local NLTK data path for each required resource and downloads
    any that are missing. Safe to call repeatedly; already-cached resources
    are left untouched.
    """
    for resource in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{resource}")
        except LookupError:
            # Not in the local cache — fetch it from the NLTK index.
            nltk.download(resource)
29
 
30
 
31
+
32
class TextRequest(BaseModel):
    """Request body for /predict: a batch of raw text documents."""

    # Each string is word-tokenized server-side before inference.
    texts: List[str]
34
+
35
+
36
# Prometheus histograms tracking model behavior in production.
# Bucket edges are chosen around the interesting decision regions of the score.
MODEL_CONFIDENCE = Histogram(
    "mention_detector_confidence",
    "Distribution of prediction confidence scores",
    buckets=[0.1, 0.3, 0.5, 0.7, 0.8, 0.9, 1.0],
)
MENTIONS_PER_DOC = Histogram(
    "mention_detector_density",
    "Number of mentions detected per document",
    buckets=[0, 1, 2, 5, 10, 20, 50],
)

# Model source and local engine layout. Both are env-overridable so the same
# image can serve a different checkpoint or cache location without a rebuild;
# defaults match the previous hard-coded values.
REPO_ID = os.getenv("REPO_ID", "kadarakos/mention-detector-poc-dry-run")
ENGINE_DIR = os.getenv("ENGINE_DIR", "engine")
MODEL_PATH = os.path.join(ENGINE_DIR, "model.onnx")

# Mutable app state populated by the lifespan handler (holds the pipeline).
state = {}
setup_nltk()
52
 
53
 
54
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handles the JIT compilation and RAM cleanup.

    On first boot (no ONNX file at MODEL_PATH) the Torch model is pulled from
    REPO_ID, exported to ENGINE_DIR, and then freed; on later boots the
    pre-built engine and its tokenizer are loaded directly. Either way the
    ready pipeline is stored in the module-level `state` dict and cleared on
    shutdown.
    """
    if not os.path.exists(MODEL_PATH):
        print(f"πŸ—οΈ Engine not found. Compiling from {REPO_ID}...")
        # create_inference_model respects HF_TOKEN env var automatically
        torch_model = create_inference_model(REPO_ID, "model_v1")
        compile_inference_model(torch_model, ENGINE_DIR)
        # Keep only the tokenizer; drop the Torch model and reclaim RAM.
        tokenizer = torch_model.tokenizer
        del torch_model
        gc.collect()
        print("βœ… Compilation complete.")
    else:
        print("πŸš€ Loading existing ONNX engine...")
        # NOTE(review): assumes compile_inference_model persists the tokenizer
        # into ENGINE_DIR so from_pretrained can find it — confirm.
        tokenizer = AutoTokenizer.from_pretrained(ENGINE_DIR)

    state["pipeline"] = ONNXMentionDetectorPipeline(MODEL_PATH, tokenizer)
    yield
    # Shutdown: release the pipeline reference.
    state.clear()
73
 
74
app = FastAPI(lifespan=lifespan)
# Attach Prometheus middleware (request count/latency) and expose /metrics.
Instrumentator().instrument(app).expose(app)
 
 
 
76
 
77
 
78
@app.post("/predict")
async def predict(request: TextRequest):
    """Tokenize the incoming texts, run mention detection, record metrics.

    Each text is word-tokenized, the batch is run through the ONNX pipeline,
    and per-document mention counts plus per-mention confidence scores are
    fed into the Prometheus histograms before the results are returned.
    """
    tokenized_docs = [word_tokenize(text) for text in request.texts]
    results = state["pipeline"].predict(tokenized_docs)

    # Observability: density and confidence distributions.
    for mentions in results:
        MENTIONS_PER_DOC.observe(len(mentions))
        for mention in mentions:
            MODEL_CONFIDENCE.observe(mention["score"])

    return {"results": results}
89
+
90
+
91
@app.get("/")
def home():
    """Landing endpoint pointing callers at the docs and metrics routes."""
    info = {
        "message": "Mention Detector API",
        "docs": "/docs",
        "metrics": "/metrics",
    }
    return info