| | import time |
| | import os |
| | import gc |
| | import json |
| | import nltk |
| | from contextlib import asynccontextmanager |
| | from typing import List |
| |
|
| | from fastapi import FastAPI |
| | from pydantic import BaseModel |
| | from transformers import AutoTokenizer |
| | from prometheus_fastapi_instrumentator import Instrumentator |
| | from prometheus_client import Histogram, Counter |
| | from nltk.tokenize import word_tokenize |
| |
|
| | |
| | from mentioned.inference import ( |
| | create_inference_model, |
| | compile_detector, |
| | compile_labeler, |
| | ONNXMentionDetectorPipeline, |
| | ONNXMentionLabelerPipeline, |
| | InferenceMentionLabeler |
| | ) |
| |
|
| |
|
def setup_nltk():
    """Ensure the NLTK tokenizer resources needed by word_tokenize are present.

    For each required resource, probe the local NLTK data path first and
    download only on a cache miss (LookupError), so repeated startups are
    fast and offline-safe once the data exists.
    """
    for resource in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{resource}")
        except LookupError:
            # Not cached locally yet — fetch it once.
            nltk.download(resource)
| |
|
| |
|
class TextRequest(BaseModel):
    """Request body for /predict: a batch of raw (untokenized) documents."""

    texts: List[str]  # one entry per document; tokenized server-side with NLTK
| |
|
| |
|
# --- Prometheus metrics -----------------------------------------------------
# Confidence buckets shared by detector and labeler: scores live in [0, 1],
# with finer resolution near the high end where thresholding decisions happen.
_CONFIDENCE_BUCKETS = [0.1, 0.3, 0.5, 0.7, 0.8, 0.9, 1.0]

# Per-mention confidence of the span detector.
MENTION_CONFIDENCE = Histogram(
    "mention_detector_confidence",
    "Distribution of prediction confidence scores for detector.",
    buckets=_CONFIDENCE_BUCKETS,
)
# Per-mention confidence of the entity labeler. Uses the same [0, 1] buckets
# as the detector for consistency (the prometheus defaults extend to 10.0,
# which wastes buckets on values a confidence score can never take).
ENTITY_CONFIDENCE = Histogram(
    "entity_labeler_confidence",
    "Distribution of prediction confidence scores for labeler.",
    buckets=_CONFIDENCE_BUCKETS,
)
# Running totals of each predicted entity label (one time series per label).
ENTITY_LABEL_COUNTS = Counter(
    "entity_label_total",
    "Total count of predicted entity labels",
    ["label_name"]
)
# Document length as seen by the model (NLTK word tokens per input text).
INPUT_TOKENS = Histogram(
    "mention_input_tokens_count",
    "Number of tokens per input document",
    buckets=[1, 5, 10, 20, 50, 100, 250, 500]
)
# mentions / tokens per document; only observed for non-empty documents.
MENTION_DENSITY = Histogram(
    "mention_density_ratio",
    "Ratio of mentions to total tokens in a document",
    buckets=[0.01, 0.05, 0.1, 0.2, 0.5]
)
# Raw mention count per document.
MENTIONS_PER_DOC = Histogram(
    "mention_detector_count",
    "Number of mentions detected per document",
    buckets=[0, 1, 2, 5, 10, 20, 50],
)

# Wall-clock time of pipeline.predict() only (excludes NLTK tokenization).
INFERENCE_LATENCY = Histogram(
    "inference_duration_seconds",
    "Time spent in the model prediction pipeline",
    buckets=[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0]
)
| |
|
# Model/source configuration, overridable through the environment.
REPO_ID = os.getenv("REPO_ID", "kadarakos/entity-labeler-poc")      # HF repo with trained weights
ENCODER_ID = os.getenv("ENCODER_ID", "distilroberta-base")          # backbone encoder checkpoint
MODEL_FACTORY = os.getenv("MODEL_FACTORY", "model_v2")              # factory key in mentioned.inference
DATA_FACTORY = os.getenv("DATA_FACTORY", "litbank_entities")        # dataset/label-space key
ENGINE_DIR = "model_v2_artifact"                                    # local dir for compiled ONNX artifacts
MODEL_PATH = os.path.join(ENGINE_DIR, "model.onnx")                 # compiled engine; presence gates JIT compile

# Shared mutable app state populated by the lifespan handler ("pipeline" key).
state = {}
# Fetch NLTK tokenizer data at import time so the first request never blocks on it.
setup_nltk()
| |
|
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """JIT compilation and loading for both Detector and Labeler.

    On startup: if no compiled ONNX engine exists, build one from the
    configured HF repo, persist its config (and tokenizer) to ENGINE_DIR,
    then load the engine into ``state["pipeline"]``. On shutdown the state
    dict is cleared so the ONNX session can be garbage-collected.
    """
    if not os.path.exists(MODEL_PATH):
        print(f"🏗️ Engine not found. Compiling {MODEL_FACTORY} from {REPO_ID}...")
        torch_model = create_inference_model(REPO_ID, ENCODER_ID, MODEL_FACTORY, DATA_FACTORY)

        # Labeler and detector need different compile paths and configs;
        # the config.json "type" field tells the load step which pipeline to build.
        if isinstance(torch_model, InferenceMentionLabeler):
            compile_labeler(torch_model, ENGINE_DIR)
            with open(os.path.join(ENGINE_DIR, "config.json"), "w") as f:
                json.dump({"id2label": torch_model.id2label, "type": "labeler"}, f)
        else:
            compile_detector(torch_model, ENGINE_DIR)
            with open(os.path.join(ENGINE_DIR, "config.json"), "w") as f:
                json.dump({"type": "detector"}, f)

        # BUG FIX: the original did `tokenizer = torch_model.tokenizer` and then
        # discarded the binding (it was overwritten by from_pretrained below).
        # Persist the tokenizer so AutoTokenizer.from_pretrained(ENGINE_DIR)
        # can find it on this and every later cold start.
        # NOTE(review): if compile_* already saves the tokenizer this is a
        # harmless re-save — confirm against mentioned.inference.
        torch_model.tokenizer.save_pretrained(ENGINE_DIR)

        # Free the torch model before loading the ONNX session to keep peak RSS down.
        del torch_model
        gc.collect()

    tokenizer = AutoTokenizer.from_pretrained(ENGINE_DIR)
    with open(os.path.join(ENGINE_DIR, "config.json"), "r") as f:
        config = json.load(f)

    if config.get("type") == "labeler":
        # JSON keys are strings; the pipeline expects integer class ids.
        id2label = {int(k): v for k, v in config["id2label"].items()}
        state["pipeline"] = ONNXMentionLabelerPipeline(MODEL_PATH, tokenizer, id2label)
    else:
        state["pipeline"] = ONNXMentionDetectorPipeline(MODEL_PATH, tokenizer)

    yield
    state.clear()
| |
|
# Build the app with the lifespan handler above, then attach the Prometheus
# instrumentator: instrument() adds default HTTP metrics, expose() mounts /metrics.
app = FastAPI(lifespan=lifespan)
Instrumentator().instrument(app).expose(app)
| |
|
| |
|
@app.post("/predict")
async def predict(request: TextRequest):
    """Tokenize the request texts, run the loaded pipeline, and record metrics.

    Returns the raw pipeline results together with the model repo id so
    callers can attribute predictions to a model version.
    """
    tokenized_docs = [word_tokenize(text) for text in request.texts]

    # Time only the model call, matching the INFERENCE_LATENCY description.
    started = time.perf_counter()
    predictions = state["pipeline"].predict(tokenized_docs)
    INFERENCE_LATENCY.observe(time.perf_counter() - started)

    # Per-document and per-mention observability.
    for tokens, mentions in zip(tokenized_docs, predictions):
        n_tokens, n_mentions = len(tokens), len(mentions)

        INPUT_TOKENS.observe(n_tokens)
        MENTIONS_PER_DOC.observe(n_mentions)
        # Guard against division by zero for empty documents.
        if n_tokens:
            MENTION_DENSITY.observe(n_mentions / n_tokens)

        for mention in mentions:
            MENTION_CONFIDENCE.observe(mention.get("score", 0))

            # Label fields are only present for labeler pipelines.
            if "label" in mention:
                ENTITY_LABEL_COUNTS.labels(label_name=mention["label"]).inc()
            if "label_score" in mention:
                ENTITY_CONFIDENCE.observe(mention["label_score"])

    return {"results": predictions, "model_repo": REPO_ID}
| |
|
| |
|
@app.get("/")
def home():
    """Landing endpoint: points clients at the interactive docs and metrics."""
    landing_page = {
        "message": "Mention Detector and Labeler API.",
        "docs": "/docs",
        "metrics": "/metrics",
    }
    return landing_page
| |
|