use anyhow::{anyhow, Context, Result}; use axum::{ body::Bytes, extract::{DefaultBodyLimit, Multipart, Query, State}, http::StatusCode, response::IntoResponse, routing::{get, post}, Json, Router, }; use crossbeam_channel::{bounded, Receiver, Sender}; use serde::{Deserialize, Serialize}; use std::{ env, ffi::CString, net::SocketAddr, os::raw::{c_char, c_float, c_int}, path::Path, sync::{ atomic::{AtomicU64, Ordering}, Arc, }, thread, time::{Duration, Instant}, }; use tokio::{sync::oneshot, task}; #[repr(C)] struct TrtContextOpaque { _private: [u8; 0], } extern "C" { fn trt_create(engine_path: *const c_char) -> *mut TrtContextOpaque; fn trt_destroy(ctx: *mut TrtContextOpaque); fn trt_max_batch(ctx: *mut TrtContextOpaque) -> c_int; fn trt_infer( ctx: *mut TrtContextOpaque, image: *const c_float, im_shape: *const c_float, scale_factor: *const c_float, batch: c_int, boxes_out: *mut c_float, counts_out: *mut i32, ) -> c_int; } const LABELS: [&str; 25] = [ "abstract", "algorithm", "aside_text", "chart", "content", "display_formula", "doc_title", "figure_title", "footer", "footer_image", "footnote", "formula_number", "header", "header_image", "image", "inline_formula", "number", "paragraph_title", "reference", "reference_content", "seal", "table", "text", "vertical_text", "vision_footnote", ]; struct Engine(*mut TrtContextOpaque); unsafe impl Send for Engine {} impl Engine { fn load(path: &str) -> Result { let c_path = CString::new(path)?; let ptr = unsafe { trt_create(c_path.as_ptr()) }; if ptr.is_null() { return Err(anyhow!("failed to load TensorRT engine: {path}")); } Ok(Self(ptr)) } fn max_batch(&self) -> usize { unsafe { trt_max_batch(self.0).max(1) as usize } } fn infer( &mut self, image: &[f32], im_shape: &[f32], scale_factor: &[f32], batch: usize, ) -> Result<(Vec, Vec)> { let mut boxes = vec![0.0f32; batch * 300 * 7]; let mut counts = vec![0i32; batch]; let code = unsafe { trt_infer( self.0, image.as_ptr(), im_shape.as_ptr(), scale_factor.as_ptr(), batch as c_int, boxes.as_mut_ptr(), counts.as_mut_ptr(), ) }; if code != 0 { return Err(anyhow!("trt_infer failed with code {code}")); } Ok((boxes, counts)) } } impl Drop for Engine { fn drop(&mut self) { unsafe { trt_destroy(self.0) } } } #[derive(Clone)] struct PageInput { image: Arc>, im_shape: [f32; 2], scale_factor: [f32; 2], } #[derive(Default)] struct Metrics { requests: AtomicU64, batches: AtomicU64, pages: AtomicU64, errors: AtomicU64, total_batch_wait_us: AtomicU64, total_infer_us: AtomicU64, } #[derive(Clone)] struct AppState { tx: Sender, metrics: Arc, default_score_threshold: f32, } struct WorkItem { enqueued: Instant, input: Option, return_boxes: bool, score_threshold: f32, respond_to: oneshot::Sender>, } #[derive(Debug, Serialize, Clone)] struct BoxResult { label: String, class_id: i32, score: f32, bbox: [i32; 4], order: i32, source: String, } #[derive(Debug, Serialize, Clone)] struct PageResult { boxes: Vec, batch_size: usize, queue_wait_us: u128, infer_us: u128, } #[derive(Debug, Serialize)] struct MetricsResponse { requests: u64, batches: u64, pages: u64, errors: u64, avg_batch_size: f64, avg_queue_wait_us: f64, avg_infer_us_per_batch: f64, } #[derive(Debug, Deserialize)] struct InferRequest { return_boxes: Option, score_threshold: Option, } #[derive(Debug, Deserialize)] struct LayoutParams { score_threshold: Option, } #[derive(Debug, Deserialize)] struct TensorLayoutParams { width: u32, height: u32, original_width: Option, original_height: Option, score_threshold: Option, } #[derive(Debug, Deserialize)] struct BatchTensorLayoutParams { batch: usize, width: u32, height: u32, original_width: Option, original_height: Option, score_threshold: Option, } fn preprocess_rgb( image: image::RgbImage, width: u32, height: u32, _keep_original: bool, ) -> PageInput { let target = 800u32; let resized = image::imageops::resize( &image, target, target, image::imageops::FilterType::Triangle, ); let mut chw = vec![0.0f32; 3 * target as usize * target as usize]; let plane = (target * target) as usize; for y in 0..target as usize { for x in 0..target as usize { let px = resized.get_pixel(x as u32, y as u32).0; let idx = y * target as usize + x; chw[idx] = px[0] as f32 / 255.0; chw[plane + idx] = px[1] as f32 / 255.0; chw[2 * plane + idx] = px[2] as f32 / 255.0; } } PageInput { image: Arc::new(chw), im_shape: [target as f32, target as f32], scale_factor: [target as f32 / height as f32, target as f32 / width as f32], } } fn load_sample(path: Option<&str>) -> Result { if let Some(path) = path { let image = image::open(Path::new(path)) .with_context(|| format!("open sample image {path}"))? .to_rgb8(); let (width, height) = image.dimensions(); Ok(preprocess_rgb(image, width, height, false)) } else { let target = 800u32; let image = image::RgbImage::from_pixel(target, target, image::Rgb([255, 255, 255])); Ok(preprocess_rgb(image, target, target, false)) } } fn preprocess_image_bytes(bytes: &[u8], _keep_original: bool) -> Result { let image = image::load_from_memory(bytes)?.to_rgb8(); let (width, height) = image.dimensions(); Ok(preprocess_rgb(image, width, height, false)) } fn preprocess_model_rgb_bytes( bytes: &[u8], width: u32, height: u32, original_width: u32, original_height: u32, _keep_original: bool, ) -> Result { let expected = width as usize * height as usize * 3; if width != 800 || height != 800 { return Err(anyhow!( "model RGB input must be 800x800, got {width}x{height}" )); } if bytes.len() != expected { return Err(anyhow!( "RGB body has {} bytes, expected {} for {}x{}x3", bytes.len(), expected, width, height )); } let mut chw = vec![0.0f32; expected]; let plane = (width * height) as usize; for idx in 0..plane { let src = idx * 3; chw[idx] = bytes[src] as f32 / 255.0; chw[plane + idx] = bytes[src + 1] as f32 / 255.0; chw[2 * plane + idx] = bytes[src + 2] as f32 / 255.0; } Ok(PageInput { image: Arc::new(chw), im_shape: [height as f32, width as f32], scale_factor: [ height as f32 / original_height.max(1) as f32, width as f32 / original_width.max(1) as f32, ], }) } fn preprocess_model_chw_u8_bytes( bytes: &[u8], width: u32, height: u32, original_width: u32, original_height: u32, _keep_original: bool, ) -> Result { if width != 800 || height != 800 { return Err(anyhow!( "model CHW u8 input must be 800x800, got {width}x{height}" )); } let expected = width as usize * height as usize * 3; if bytes.len() != expected { return Err(anyhow!( "CHW u8 body has {} bytes, expected {} for 3x{}x{}", bytes.len(), expected, height, width )); } let mut chw = vec![0.0f32; expected]; for (dst, src) in chw.iter_mut().zip(bytes.iter()) { *dst = *src as f32 / 255.0; } Ok(PageInput { image: Arc::new(chw), im_shape: [height as f32, width as f32], scale_factor: [ height as f32 / original_height.max(1) as f32, width as f32 / original_width.max(1) as f32, ], }) } fn preprocess_model_chw_u8_batch_bytes( bytes: &[u8], batch: usize, width: u32, height: u32, original_width: u32, original_height: u32, _keep_original: bool, ) -> Result> { if batch == 0 { return Err(anyhow!("batch must be > 0")); } let per_page = width as usize * height as usize * 3; let expected = batch * per_page; if bytes.len() != expected { return Err(anyhow!( "batched CHW u8 body has {} bytes, expected {} for batch={} 3x{}x{}", bytes.len(), expected, batch, height, width )); } (0..batch) .into_iter() .map(|idx| { let start = idx * per_page; let end = start + per_page; preprocess_model_chw_u8_bytes( &bytes[start..end], width, height, original_width, original_height, false, ) }) .collect() } fn preprocess_model_chw_f32_bytes( bytes: &[u8], width: u32, height: u32, original_width: u32, original_height: u32, _keep_original: bool, ) -> Result { if width != 800 || height != 800 { return Err(anyhow!( "model CHW f32 input must be 800x800, got {width}x{height}" )); } let floats = width as usize * height as usize * 3; let expected = floats * std::mem::size_of::(); if bytes.len() != expected { return Err(anyhow!( "CHW f32 body has {} bytes, expected {} for 3x{}x{}", bytes.len(), expected, height, width )); } let mut chw = vec![0.0f32; floats]; unsafe { std::ptr::copy_nonoverlapping(bytes.as_ptr(), chw.as_mut_ptr() as *mut u8, expected); } Ok(PageInput { image: Arc::new(chw), im_shape: [height as f32, width as f32], scale_factor: [ height as f32 / original_height.max(1) as f32, width as f32 / original_width.max(1) as f32, ], }) } fn decode_page( boxes: &[f32], count: usize, return_boxes: bool, score_threshold: f32, ) -> Vec { if !return_boxes { return Vec::new(); } let mut out = Vec::new(); let limit = count.min(300); for i in 0..limit { let row = &boxes[i * 7..i * 7 + 7]; let class_id = row[0] as i32; let score = row[1]; if score < score_threshold || class_id < 0 || class_id as usize >= LABELS.len() { continue; } out.push(BoxResult { label: LABELS[class_id as usize].to_string(), class_id, score, bbox: [ row[2].round() as i32, row[3].round() as i32, row[4].round() as i32, row[5].round() as i32, ], order: row[6] as i32, source: "pp_doclayout_v3".to_string(), }); } out.sort_by_key(|b| b.order); out } fn push_input( dst_image: &mut Vec, dst_shape: &mut Vec, dst_scale: &mut Vec, input: &PageInput, ) { dst_image.extend_from_slice(&input.image); dst_shape.extend_from_slice(&input.im_shape); dst_scale.extend_from_slice(&input.scale_factor); } fn batch_worker( rx: Receiver, metrics: Arc, engine_path: String, sample: PageInput, max_batch_cfg: usize, max_delay: Duration, ) { let mut engine = Engine::load(&engine_path).expect("load TensorRT engine"); let max_batch = max_batch_cfg.min(engine.max_batch()).max(1); let mut cached_images: Vec> = Vec::with_capacity(max_batch + 1); let mut cached_shapes: Vec> = Vec::with_capacity(max_batch + 1); let mut cached_scales: Vec> = Vec::with_capacity(max_batch + 1); cached_images.push(Vec::new()); cached_shapes.push(Vec::new()); cached_scales.push(Vec::new()); for b in 1..=max_batch { let mut image = Vec::with_capacity(b * sample.image.len()); let mut im_shape = Vec::with_capacity(b * 2); let mut scale = Vec::with_capacity(b * 2); for _ in 0..b { push_input(&mut image, &mut im_shape, &mut scale, &sample); } cached_images.push(image); cached_shapes.push(im_shape); cached_scales.push(scale); } tracing::info!( engine_path, max_batch, delay_us = max_delay.as_micros(), "rust batcher ready" ); loop { let first = match rx.recv() { Ok(item) => item, Err(_) => break, }; let mut items = vec![first]; let deadline = Instant::now() + max_delay; while items.len() < max_batch { let now = Instant::now(); if now >= deadline { break; } match rx.recv_timeout(deadline - now) { Ok(item) => items.push(item), Err(_) => break, } } let batch = items.len(); let all_sample = items.iter().all(|item| item.input.is_none()); let mut image_owned = Vec::new(); let mut shape_owned = Vec::new(); let mut scale_owned = Vec::new(); let (image_ref, shape_ref, scale_ref) = if all_sample { ( &cached_images[batch], &cached_shapes[batch], &cached_scales[batch], ) } else { image_owned.reserve(batch * sample.image.len()); shape_owned.reserve(batch * 2); scale_owned.reserve(batch * 2); for item in &items { let input = item.input.as_ref().unwrap_or(&sample); push_input(&mut image_owned, &mut shape_owned, &mut scale_owned, input); } (&image_owned, &shape_owned, &scale_owned) }; let infer_start = Instant::now(); let infer_result = engine.infer(image_ref, shape_ref, scale_ref, batch); let infer_us = infer_start.elapsed().as_micros(); metrics.batches.fetch_add(1, Ordering::Relaxed); metrics.pages.fetch_add(batch as u64, Ordering::Relaxed); metrics .total_infer_us .fetch_add(infer_us as u64, Ordering::Relaxed); match infer_result { Ok((boxes, counts)) => { for (idx, item) in items.into_iter().enumerate() { let wait_us = infer_start.duration_since(item.enqueued).as_micros(); metrics .total_batch_wait_us .fetch_add(wait_us as u64, Ordering::Relaxed); let start = idx * 300 * 7; let end = start + 300 * 7; let decoded = decode_page( &boxes[start..end], counts[idx].max(0) as usize, item.return_boxes, item.score_threshold, ); let _ = item.respond_to.send(Ok(PageResult { boxes: decoded, batch_size: batch, queue_wait_us: wait_us, infer_us, })); } } Err(e) => { metrics.errors.fetch_add(batch as u64, Ordering::Relaxed); let msg = e.to_string(); for item in items { let _ = item.respond_to.send(Err(msg.clone())); } } } } } fn run_self_test(engine_path: &str, sample: &PageInput, batch: usize, iters: usize) -> Result<()> { let mut engine = Engine::load(engine_path)?; let max_batch = batch.min(engine.max_batch()).max(1); let mut image = Vec::with_capacity(max_batch * sample.image.len()); let mut im_shape = Vec::with_capacity(max_batch * 2); let mut scale = Vec::with_capacity(max_batch * 2); for _ in 0..max_batch { push_input(&mut image, &mut im_shape, &mut scale, sample); } for _ in 0..iters { let (_boxes, counts) = engine.infer(&image, &im_shape, &scale, max_batch)?; if counts.len() != max_batch { return Err(anyhow!( "unexpected count length {} for batch {}", counts.len(), max_batch )); } } println!("self_test_ok batch={max_batch} iters={iters}"); Ok(()) } async fn health() -> impl IntoResponse { Json(serde_json::json!({"status":"ok"})) } async fn metrics_handler(State(state): State) -> impl IntoResponse { let requests = state.metrics.requests.load(Ordering::Relaxed); let batches = state.metrics.batches.load(Ordering::Relaxed); let pages = state.metrics.pages.load(Ordering::Relaxed); let errors = state.metrics.errors.load(Ordering::Relaxed); let wait = state.metrics.total_batch_wait_us.load(Ordering::Relaxed); let infer = state.metrics.total_infer_us.load(Ordering::Relaxed); Json(MetricsResponse { requests, batches, pages, errors, avg_batch_size: if batches == 0 { 0.0 } else { pages as f64 / batches as f64 }, avg_queue_wait_us: if requests == 0 { 0.0 } else { wait as f64 / requests as f64 }, avg_infer_us_per_batch: if batches == 0 { 0.0 } else { infer as f64 / batches as f64 }, }) } fn score_threshold_or_default(value: Option, default: f32) -> f32 { value.unwrap_or(default).clamp(0.0, 1.0) } async fn enqueue_pages( state: &AppState, inputs: Vec, return_boxes: bool, score_threshold: f32, ) -> Result, (StatusCode, Json)> { let mut receivers = Vec::with_capacity(inputs.len()); for input in inputs { let (tx, rx) = oneshot::channel(); state.metrics.requests.fetch_add(1, Ordering::Relaxed); if state .tx .send(WorkItem { enqueued: Instant::now(), input: Some(input), return_boxes, score_threshold, respond_to: tx, }) .is_err() { state.metrics.errors.fetch_add(1, Ordering::Relaxed); return Err(( StatusCode::SERVICE_UNAVAILABLE, Json(serde_json::json!({"error":"batch worker stopped"})), )); } receivers.push(rx); } let mut results = Vec::with_capacity(receivers.len()); for rx in receivers { match rx.await { Ok(Ok(result)) => results.push(result), Ok(Err(err)) => { return Err(( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":err})), )); } Err(_) => { return Err(( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":"response channel closed"})), )); } } } Ok(results) } async fn enqueue_page( state: &AppState, input: Option, return_boxes: bool, score_threshold: f32, ) -> Result)> { let (tx, rx) = oneshot::channel(); state.metrics.requests.fetch_add(1, Ordering::Relaxed); if state .tx .send(WorkItem { enqueued: Instant::now(), input, return_boxes, score_threshold, respond_to: tx, }) .is_err() { state.metrics.errors.fetch_add(1, Ordering::Relaxed); return Err(( StatusCode::SERVICE_UNAVAILABLE, Json(serde_json::json!({"error":"batch worker stopped"})), )); } match rx.await { Ok(Ok(result)) => Ok(result), Ok(Err(err)) => Err(( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":err})), )), Err(_) => Err(( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":"response channel closed"})), )), } } async fn infer(State(state): State, Json(req): Json) -> impl IntoResponse { let score_threshold = score_threshold_or_default(req.score_threshold, state.default_score_threshold); match enqueue_page( &state, None, req.return_boxes.unwrap_or(true), score_threshold, ) .await { Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(), Err(err) => err.into_response(), } } async fn layout( State(state): State, Query(params): Query, mut multipart: Multipart, ) -> impl IntoResponse { let score_threshold = score_threshold_or_default(params.score_threshold, state.default_score_threshold); let mut preprocess_tasks = Vec::new(); while let Some(field) = match multipart.next_field().await { Ok(field) => field, Err(err) => { return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":err.to_string()})), ) .into_response(); } } { let Some(name) = field.name().map(str::to_string) else { continue; }; if name != "file" && name != "files" { continue; } let bytes = match field.bytes().await { Ok(bytes) => bytes, Err(err) => { return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":err.to_string()})), ) .into_response(); } }; preprocess_tasks.push(task::spawn_blocking(move || { preprocess_image_bytes(&bytes, false) })); } if preprocess_tasks.is_empty() { return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":"multipart field 'file' or 'files' is required"})), ) .into_response(); } let mut inputs = Vec::with_capacity(preprocess_tasks.len()); for handle in preprocess_tasks { match handle.await { Ok(Ok(input)) => inputs.push(input), Ok(Err(err)) => { return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":err.to_string()})), ) .into_response(); } Err(err) => { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), ) .into_response(); } } } let mut results = Vec::with_capacity(inputs.len()); for input in inputs { match enqueue_page(&state, Some(input.clone()), true, score_threshold).await { Ok(result) => results.push(result), Err(err) => return err.into_response(), } } Json(serde_json::json!({"pages":results.len(),"results":results})).into_response() } async fn layout_tensor( State(state): State, Query(params): Query, body: Bytes, ) -> impl IntoResponse { let original_width = params.original_width.unwrap_or(params.width); let original_height = params.original_height.unwrap_or(params.height); let input = match task::spawn_blocking(move || { preprocess_model_rgb_bytes( &body, params.width, params.height, original_width, original_height, false, ) }) .await { Ok(Ok(input)) => input, Ok(Err(err)) => { return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":err.to_string()})), ) .into_response(); } Err(err) => { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), ) .into_response(); } }; let score_threshold = score_threshold_or_default(params.score_threshold, state.default_score_threshold); match enqueue_page(&state, Some(input.clone()), true, score_threshold).await { Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(), Err(err) => err.into_response(), } } async fn layout_chw_u8( State(state): State, Query(params): Query, body: Bytes, ) -> impl IntoResponse { let original_width = params.original_width.unwrap_or(params.width); let original_height = params.original_height.unwrap_or(params.height); let input = match task::spawn_blocking(move || { preprocess_model_chw_u8_bytes( &body, params.width, params.height, original_width, original_height, false, ) }) .await { Ok(Ok(input)) => input, Ok(Err(err)) => { return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":err.to_string()})), ) .into_response(); } Err(err) => { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), ) .into_response(); } }; let score_threshold = score_threshold_or_default(params.score_threshold, state.default_score_threshold); match enqueue_page(&state, Some(input.clone()), true, score_threshold).await { Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(), Err(err) => err.into_response(), } } async fn layout_chw_u8_batch( State(state): State, Query(params): Query, body: Bytes, ) -> impl IntoResponse { let original_width = params.original_width.unwrap_or(params.width); let original_height = params.original_height.unwrap_or(params.height); let batch = params.batch; let inputs = match task::spawn_blocking(move || { preprocess_model_chw_u8_batch_bytes( &body, batch, params.width, params.height, original_width, original_height, false, ) }) .await { Ok(Ok(inputs)) => inputs, Ok(Err(err)) => { return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":err.to_string()})), ) .into_response(); } Err(err) => { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), ) .into_response(); } }; let score_threshold = score_threshold_or_default(params.score_threshold, state.default_score_threshold); match enqueue_pages(&state, inputs.clone(), true, score_threshold).await { Ok(results) => { Json(serde_json::json!({"pages":results.len(),"results":results})).into_response() } Err(err) => err.into_response(), } } async fn layout_chw_f32( State(state): State, Query(params): Query, body: Bytes, ) -> impl IntoResponse { let original_width = params.original_width.unwrap_or(params.width); let original_height = params.original_height.unwrap_or(params.height); let input = match task::spawn_blocking(move || { preprocess_model_chw_f32_bytes( &body, params.width, params.height, original_width, original_height, false, ) }) .await { Ok(Ok(input)) => input, Ok(Err(err)) => { return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":err.to_string()})), ) .into_response(); } Err(err) => { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), ) .into_response(); } }; let score_threshold = score_threshold_or_default(params.score_threshold, state.default_score_threshold); match enqueue_page(&state, Some(input.clone()), true, score_threshold).await { Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(), Err(err) => err.into_response(), } } #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .init(); let engine_path = env::var("DOC_LAYOUT_ENGINE").context("DOC_LAYOUT_ENGINE is required")?; let sample_path = env::var("DOC_LAYOUT_SAMPLE_IMAGE").ok(); let max_batch = env::var("DOC_LAYOUT_MAX_BATCH") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(8); let max_delay_us = env::var("DOC_LAYOUT_MAX_DELAY_US") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(1000u64); let queue_capacity = env::var("DOC_LAYOUT_QUEUE_CAPACITY") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(4096usize); let default_score_threshold = env::var("DOC_LAYOUT_SCORE_THRESHOLD") .ok() .and_then(|v| v.parse::().ok()) .unwrap_or(0.5) .clamp(0.0, 1.0); let max_upload_mb = env::var("DOC_LAYOUT_MAX_UPLOAD_MB") .ok() .and_then(|v| v.parse::().ok()) .unwrap_or(512); let workers = env::var("DOC_LAYOUT_WORKERS") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(1usize) .max(1); let port = env::var("DOC_LAYOUT_PORT") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(18082u16); let sample = load_sample(sample_path.as_deref())?; if let Some(iters) = env::var("DOC_LAYOUT_SELF_TEST_ITERS") .ok() .and_then(|v| v.parse::().ok()) { let batch = env::var("DOC_LAYOUT_SELF_TEST_BATCH") .ok() .and_then(|v| v.parse::().ok()) .unwrap_or(max_batch); return run_self_test(&engine_path, &sample, batch, iters); } let (tx, rx) = bounded::(queue_capacity); let metrics = Arc::new(Metrics::default()); for worker_id in 0..workers { let worker_rx = rx.clone(); let worker_metrics = metrics.clone(); let worker_engine_path = engine_path.clone(); let worker_sample = sample.clone(); thread::spawn(move || { tracing::info!(worker_id, "starting batch worker"); batch_worker( worker_rx, worker_metrics, worker_engine_path, worker_sample, max_batch, Duration::from_micros(max_delay_us), ); }); } let state = AppState { tx, metrics, default_score_threshold, }; let app = Router::new() .route("/health", get(health)) .route("/metrics", get(metrics_handler)) .route("/v1/infer", post(infer)) .route("/v1/layout", post(layout)) .route("/v1/layout_tensor", post(layout_tensor)) .route("/v1/layout_chw_f32", post(layout_chw_f32)) .route("/v1/layout_chw_u8", post(layout_chw_u8)) .route("/v1/layout_chw_u8_batch", post(layout_chw_u8_batch)) .layer(DefaultBodyLimit::max(max_upload_mb * 1024 * 1024)) .with_state(state); let addr = SocketAddr::from(([0, 0, 0, 0], port)); tracing::info!( %addr, workers, max_batch, max_delay_us, max_upload_mb, default_score_threshold, "starting doclayout TensorRT batcher" ); let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, app).await?; Ok(()) }