Instructions to use bndos/pp-doclayout-v3-trt with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- TensorRT
How to use bndos/pp-doclayout-v3-trt with TensorRT:
# No code snippets available yet for this library. # To use this model, check the repository files and the library's documentation. # Want to help? PRs adding snippets are welcome at: # https://github.com/huggingface/huggingface.js
- Notebooks
- Google Colab
- Kaggle
| use anyhow::{anyhow, Context, Result}; | |
| use axum::{ | |
| body::Bytes, | |
| extract::{DefaultBodyLimit, Multipart, Query, State}, | |
| http::StatusCode, | |
| response::IntoResponse, | |
| routing::{get, post}, | |
| Json, Router, | |
| }; | |
| use crossbeam_channel::{bounded, Receiver, Sender}; | |
| use serde::{Deserialize, Serialize}; | |
| use std::{ | |
| env, | |
| ffi::CString, | |
| net::SocketAddr, | |
| os::raw::{c_char, c_float, c_int}, | |
| path::Path, | |
| sync::{ | |
| atomic::{AtomicU64, Ordering}, | |
| Arc, | |
| }, | |
| thread, | |
| time::{Duration, Instant}, | |
| }; | |
| use tokio::{sync::oneshot, task}; | |
/// Opaque stand-in for the native TensorRT context type.
///
/// Never constructed from Rust; it only appears behind the raw pointers that
/// cross the FFI boundary below. Follows the Rustonomicon opaque-type pattern:
/// `#[repr(C)]` makes the type FFI-safe for the `improper_ctypes` lint, the
/// zero-length private field gives it no size and no public constructor, and
/// the `PhantomData` marker keeps it `!Send`, `!Sync` and `!Unpin` so any
/// thread-safety has to be opted into explicitly by the owning wrapper.
#[repr(C)]
struct TrtContextOpaque {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

// Native inference API. NOTE(review): the implementation is not in this file —
// presumably a C/C++ shim linked at build time; confirm the contracts below
// against it. They describe only how this file uses each function.
extern "C" {
    /// Creates a context from the NUL-terminated engine file path.
    /// Callers in this file treat a null return as failure.
    fn trt_create(engine_path: *const c_char) -> *mut TrtContextOpaque;
    /// Releases a context previously returned by `trt_create`.
    fn trt_destroy(ctx: *mut TrtContextOpaque);
    /// Largest batch size the context accepts; callers clamp values below 1 to 1.
    fn trt_max_batch(ctx: *mut TrtContextOpaque) -> c_int;
    /// Runs inference on `batch` pages. Callers size `boxes_out` for
    /// `batch * 300 * 7` floats and `counts_out` for `batch` ints, and treat
    /// any non-zero return code as failure.
    fn trt_infer(
        ctx: *mut TrtContextOpaque,
        image: *const c_float,
        im_shape: *const c_float,
        scale_factor: *const c_float,
        batch: c_int,
        boxes_out: *mut c_float,
        counts_out: *mut i32,
    ) -> c_int;
}
/// Class names, indexed by the integer `class_id` the engine reports in the
/// first column of each detection row (looked up in `decode_page`).
const LABELS: [&str; 25] = [
    "abstract", "algorithm", "aside_text", "chart", "content",
    "display_formula", "doc_title", "figure_title", "footer", "footer_image",
    "footnote", "formula_number", "header", "header_image", "image",
    "inline_formula", "number", "paragraph_title", "reference", "reference_content",
    "seal", "table", "text", "vertical_text", "vision_footnote",
];
| struct Engine(*mut TrtContextOpaque); | |
| unsafe impl Send for Engine {} | |
| impl Engine { | |
| fn load(path: &str) -> Result<Self> { | |
| let c_path = CString::new(path)?; | |
| let ptr = unsafe { trt_create(c_path.as_ptr()) }; | |
| if ptr.is_null() { | |
| return Err(anyhow!("failed to load TensorRT engine: {path}")); | |
| } | |
| Ok(Self(ptr)) | |
| } | |
| fn max_batch(&self) -> usize { | |
| unsafe { trt_max_batch(self.0).max(1) as usize } | |
| } | |
| fn infer( | |
| &mut self, | |
| image: &[f32], | |
| im_shape: &[f32], | |
| scale_factor: &[f32], | |
| batch: usize, | |
| ) -> Result<(Vec<f32>, Vec<i32>)> { | |
| let mut boxes = vec![0.0f32; batch * 300 * 7]; | |
| let mut counts = vec![0i32; batch]; | |
| let code = unsafe { | |
| trt_infer( | |
| self.0, | |
| image.as_ptr(), | |
| im_shape.as_ptr(), | |
| scale_factor.as_ptr(), | |
| batch as c_int, | |
| boxes.as_mut_ptr(), | |
| counts.as_mut_ptr(), | |
| ) | |
| }; | |
| if code != 0 { | |
| return Err(anyhow!("trt_infer failed with code {code}")); | |
| } | |
| Ok((boxes, counts)) | |
| } | |
| } | |
| impl Drop for Engine { | |
| fn drop(&mut self) { | |
| unsafe { trt_destroy(self.0) } | |
| } | |
| } | |
/// One preprocessed page ready to be packed into an inference batch.
///
/// Cloning is cheap: the pixel tensor sits behind an `Arc`, so clones share it.
#[derive(Clone)]
struct PageInput {
    /// Normalized pixel tensor (values in `[0, 1]`), channel-planar.
    image: Arc<Vec<f32>>,
    /// `[height, width]` of the tensor in `image`.
    im_shape: [f32; 2],
    /// `[y_scale, x_scale]` from the original page to the tensor size.
    scale_factor: [f32; 2],
}
/// Process-wide counters shared by the HTTP handlers and the batch workers.
///
/// All counters start at zero and are updated with relaxed atomics.
#[derive(Default)]
struct Metrics {
    /// Pages submitted to the work queue (one per page, not per HTTP call).
    requests: AtomicU64,
    /// Inference batches executed.
    batches: AtomicU64,
    /// Pages processed by workers, whether inference succeeded or not.
    pages: AtomicU64,
    /// Pages that failed, either to enqueue or during inference.
    errors: AtomicU64,
    /// Sum of enqueue-to-inference-start waits for successfully served pages.
    total_batch_wait_us: AtomicU64,
    /// Sum of per-batch inference wall time.
    total_infer_us: AtomicU64,
}
| struct AppState { | |
| tx: Sender<WorkItem>, | |
| metrics: Arc<Metrics>, | |
| default_score_threshold: f32, | |
| } | |
/// One page queued for a batch worker, together with the channel its result
/// is sent back on.
struct WorkItem {
    /// When the page was queued; workers measure queue wait from here.
    enqueued: Instant,
    /// The page to run, or `None` to run the worker's sample page instead.
    input: Option<PageInput>,
    /// When false, the worker replies with an empty box list.
    return_boxes: bool,
    /// Minimum score a detection needs to be returned.
    score_threshold: f32,
    /// Receives the page result or an error message.
    respond_to: oneshot::Sender<Result<PageResult, String>>,
}
| struct BoxResult { | |
| label: String, | |
| class_id: i32, | |
| score: f32, | |
| bbox: [i32; 4], | |
| order: i32, | |
| source: String, | |
| } | |
| struct PageResult { | |
| boxes: Vec<BoxResult>, | |
| batch_size: usize, | |
| queue_wait_us: u128, | |
| infer_us: u128, | |
| } | |
| struct MetricsResponse { | |
| requests: u64, | |
| batches: u64, | |
| pages: u64, | |
| errors: u64, | |
| avg_batch_size: f64, | |
| avg_queue_wait_us: f64, | |
| avg_infer_us_per_batch: f64, | |
| } | |
| struct InferRequest { | |
| return_boxes: Option<bool>, | |
| score_threshold: Option<f32>, | |
| } | |
| struct LayoutParams { | |
| score_threshold: Option<f32>, | |
| } | |
| struct TensorLayoutParams { | |
| width: u32, | |
| height: u32, | |
| original_width: Option<u32>, | |
| original_height: Option<u32>, | |
| score_threshold: Option<f32>, | |
| } | |
| struct BatchTensorLayoutParams { | |
| batch: usize, | |
| width: u32, | |
| height: u32, | |
| original_width: Option<u32>, | |
| original_height: Option<u32>, | |
| score_threshold: Option<f32>, | |
| } | |
| fn preprocess_rgb( | |
| image: image::RgbImage, | |
| width: u32, | |
| height: u32, | |
| _keep_original: bool, | |
| ) -> PageInput { | |
| let target = 800u32; | |
| let resized = image::imageops::resize( | |
| &image, | |
| target, | |
| target, | |
| image::imageops::FilterType::Triangle, | |
| ); | |
| let mut chw = vec![0.0f32; 3 * target as usize * target as usize]; | |
| let plane = (target * target) as usize; | |
| for y in 0..target as usize { | |
| for x in 0..target as usize { | |
| let px = resized.get_pixel(x as u32, y as u32).0; | |
| let idx = y * target as usize + x; | |
| chw[idx] = px[0] as f32 / 255.0; | |
| chw[plane + idx] = px[1] as f32 / 255.0; | |
| chw[2 * plane + idx] = px[2] as f32 / 255.0; | |
| } | |
| } | |
| PageInput { | |
| image: Arc::new(chw), | |
| im_shape: [target as f32, target as f32], | |
| scale_factor: [target as f32 / height as f32, target as f32 / width as f32], | |
| } | |
| } | |
| fn load_sample(path: Option<&str>) -> Result<PageInput> { | |
| if let Some(path) = path { | |
| let image = image::open(Path::new(path)) | |
| .with_context(|| format!("open sample image {path}"))? | |
| .to_rgb8(); | |
| let (width, height) = image.dimensions(); | |
| Ok(preprocess_rgb(image, width, height, false)) | |
| } else { | |
| let target = 800u32; | |
| let image = image::RgbImage::from_pixel(target, target, image::Rgb([255, 255, 255])); | |
| Ok(preprocess_rgb(image, target, target, false)) | |
| } | |
| } | |
| fn preprocess_image_bytes(bytes: &[u8], _keep_original: bool) -> Result<PageInput> { | |
| let image = image::load_from_memory(bytes)?.to_rgb8(); | |
| let (width, height) = image.dimensions(); | |
| Ok(preprocess_rgb(image, width, height, false)) | |
| } | |
| fn preprocess_model_rgb_bytes( | |
| bytes: &[u8], | |
| width: u32, | |
| height: u32, | |
| original_width: u32, | |
| original_height: u32, | |
| _keep_original: bool, | |
| ) -> Result<PageInput> { | |
| let expected = width as usize * height as usize * 3; | |
| if width != 800 || height != 800 { | |
| return Err(anyhow!( | |
| "model RGB input must be 800x800, got {width}x{height}" | |
| )); | |
| } | |
| if bytes.len() != expected { | |
| return Err(anyhow!( | |
| "RGB body has {} bytes, expected {} for {}x{}x3", | |
| bytes.len(), | |
| expected, | |
| width, | |
| height | |
| )); | |
| } | |
| let mut chw = vec![0.0f32; expected]; | |
| let plane = (width * height) as usize; | |
| for idx in 0..plane { | |
| let src = idx * 3; | |
| chw[idx] = bytes[src] as f32 / 255.0; | |
| chw[plane + idx] = bytes[src + 1] as f32 / 255.0; | |
| chw[2 * plane + idx] = bytes[src + 2] as f32 / 255.0; | |
| } | |
| Ok(PageInput { | |
| image: Arc::new(chw), | |
| im_shape: [height as f32, width as f32], | |
| scale_factor: [ | |
| height as f32 / original_height.max(1) as f32, | |
| width as f32 / original_width.max(1) as f32, | |
| ], | |
| }) | |
| } | |
| fn preprocess_model_chw_u8_bytes( | |
| bytes: &[u8], | |
| width: u32, | |
| height: u32, | |
| original_width: u32, | |
| original_height: u32, | |
| _keep_original: bool, | |
| ) -> Result<PageInput> { | |
| if width != 800 || height != 800 { | |
| return Err(anyhow!( | |
| "model CHW u8 input must be 800x800, got {width}x{height}" | |
| )); | |
| } | |
| let expected = width as usize * height as usize * 3; | |
| if bytes.len() != expected { | |
| return Err(anyhow!( | |
| "CHW u8 body has {} bytes, expected {} for 3x{}x{}", | |
| bytes.len(), | |
| expected, | |
| height, | |
| width | |
| )); | |
| } | |
| let mut chw = vec![0.0f32; expected]; | |
| for (dst, src) in chw.iter_mut().zip(bytes.iter()) { | |
| *dst = *src as f32 / 255.0; | |
| } | |
| Ok(PageInput { | |
| image: Arc::new(chw), | |
| im_shape: [height as f32, width as f32], | |
| scale_factor: [ | |
| height as f32 / original_height.max(1) as f32, | |
| width as f32 / original_width.max(1) as f32, | |
| ], | |
| }) | |
| } | |
| fn preprocess_model_chw_u8_batch_bytes( | |
| bytes: &[u8], | |
| batch: usize, | |
| width: u32, | |
| height: u32, | |
| original_width: u32, | |
| original_height: u32, | |
| _keep_original: bool, | |
| ) -> Result<Vec<PageInput>> { | |
| if batch == 0 { | |
| return Err(anyhow!("batch must be > 0")); | |
| } | |
| let per_page = width as usize * height as usize * 3; | |
| let expected = batch * per_page; | |
| if bytes.len() != expected { | |
| return Err(anyhow!( | |
| "batched CHW u8 body has {} bytes, expected {} for batch={} 3x{}x{}", | |
| bytes.len(), | |
| expected, | |
| batch, | |
| height, | |
| width | |
| )); | |
| } | |
| (0..batch) | |
| .into_iter() | |
| .map(|idx| { | |
| let start = idx * per_page; | |
| let end = start + per_page; | |
| preprocess_model_chw_u8_bytes( | |
| &bytes[start..end], | |
| width, | |
| height, | |
| original_width, | |
| original_height, | |
| false, | |
| ) | |
| }) | |
| .collect() | |
| } | |
| fn preprocess_model_chw_f32_bytes( | |
| bytes: &[u8], | |
| width: u32, | |
| height: u32, | |
| original_width: u32, | |
| original_height: u32, | |
| _keep_original: bool, | |
| ) -> Result<PageInput> { | |
| if width != 800 || height != 800 { | |
| return Err(anyhow!( | |
| "model CHW f32 input must be 800x800, got {width}x{height}" | |
| )); | |
| } | |
| let floats = width as usize * height as usize * 3; | |
| let expected = floats * std::mem::size_of::<f32>(); | |
| if bytes.len() != expected { | |
| return Err(anyhow!( | |
| "CHW f32 body has {} bytes, expected {} for 3x{}x{}", | |
| bytes.len(), | |
| expected, | |
| height, | |
| width | |
| )); | |
| } | |
| let mut chw = vec![0.0f32; floats]; | |
| unsafe { | |
| std::ptr::copy_nonoverlapping(bytes.as_ptr(), chw.as_mut_ptr() as *mut u8, expected); | |
| } | |
| Ok(PageInput { | |
| image: Arc::new(chw), | |
| im_shape: [height as f32, width as f32], | |
| scale_factor: [ | |
| height as f32 / original_height.max(1) as f32, | |
| width as f32 / original_width.max(1) as f32, | |
| ], | |
| }) | |
| } | |
| fn decode_page( | |
| boxes: &[f32], | |
| count: usize, | |
| return_boxes: bool, | |
| score_threshold: f32, | |
| ) -> Vec<BoxResult> { | |
| if !return_boxes { | |
| return Vec::new(); | |
| } | |
| let mut out = Vec::new(); | |
| let limit = count.min(300); | |
| for i in 0..limit { | |
| let row = &boxes[i * 7..i * 7 + 7]; | |
| let class_id = row[0] as i32; | |
| let score = row[1]; | |
| if score < score_threshold || class_id < 0 || class_id as usize >= LABELS.len() { | |
| continue; | |
| } | |
| out.push(BoxResult { | |
| label: LABELS[class_id as usize].to_string(), | |
| class_id, | |
| score, | |
| bbox: [ | |
| row[2].round() as i32, | |
| row[3].round() as i32, | |
| row[4].round() as i32, | |
| row[5].round() as i32, | |
| ], | |
| order: row[6] as i32, | |
| source: "pp_doclayout_v3".to_string(), | |
| }); | |
| } | |
| out.sort_by_key(|b| b.order); | |
| out | |
| } | |
| fn push_input( | |
| dst_image: &mut Vec<f32>, | |
| dst_shape: &mut Vec<f32>, | |
| dst_scale: &mut Vec<f32>, | |
| input: &PageInput, | |
| ) { | |
| dst_image.extend_from_slice(&input.image); | |
| dst_shape.extend_from_slice(&input.im_shape); | |
| dst_scale.extend_from_slice(&input.scale_factor); | |
| } | |
| fn batch_worker( | |
| rx: Receiver<WorkItem>, | |
| metrics: Arc<Metrics>, | |
| engine_path: String, | |
| sample: PageInput, | |
| max_batch_cfg: usize, | |
| max_delay: Duration, | |
| ) { | |
| let mut engine = Engine::load(&engine_path).expect("load TensorRT engine"); | |
| let max_batch = max_batch_cfg.min(engine.max_batch()).max(1); | |
| let mut cached_images: Vec<Vec<f32>> = Vec::with_capacity(max_batch + 1); | |
| let mut cached_shapes: Vec<Vec<f32>> = Vec::with_capacity(max_batch + 1); | |
| let mut cached_scales: Vec<Vec<f32>> = Vec::with_capacity(max_batch + 1); | |
| cached_images.push(Vec::new()); | |
| cached_shapes.push(Vec::new()); | |
| cached_scales.push(Vec::new()); | |
| for b in 1..=max_batch { | |
| let mut image = Vec::with_capacity(b * sample.image.len()); | |
| let mut im_shape = Vec::with_capacity(b * 2); | |
| let mut scale = Vec::with_capacity(b * 2); | |
| for _ in 0..b { | |
| push_input(&mut image, &mut im_shape, &mut scale, &sample); | |
| } | |
| cached_images.push(image); | |
| cached_shapes.push(im_shape); | |
| cached_scales.push(scale); | |
| } | |
| tracing::info!( | |
| engine_path, | |
| max_batch, | |
| delay_us = max_delay.as_micros(), | |
| "rust batcher ready" | |
| ); | |
| loop { | |
| let first = match rx.recv() { | |
| Ok(item) => item, | |
| Err(_) => break, | |
| }; | |
| let mut items = vec![first]; | |
| let deadline = Instant::now() + max_delay; | |
| while items.len() < max_batch { | |
| let now = Instant::now(); | |
| if now >= deadline { | |
| break; | |
| } | |
| match rx.recv_timeout(deadline - now) { | |
| Ok(item) => items.push(item), | |
| Err(_) => break, | |
| } | |
| } | |
| let batch = items.len(); | |
| let all_sample = items.iter().all(|item| item.input.is_none()); | |
| let mut image_owned = Vec::new(); | |
| let mut shape_owned = Vec::new(); | |
| let mut scale_owned = Vec::new(); | |
| let (image_ref, shape_ref, scale_ref) = if all_sample { | |
| ( | |
| &cached_images[batch], | |
| &cached_shapes[batch], | |
| &cached_scales[batch], | |
| ) | |
| } else { | |
| image_owned.reserve(batch * sample.image.len()); | |
| shape_owned.reserve(batch * 2); | |
| scale_owned.reserve(batch * 2); | |
| for item in &items { | |
| let input = item.input.as_ref().unwrap_or(&sample); | |
| push_input(&mut image_owned, &mut shape_owned, &mut scale_owned, input); | |
| } | |
| (&image_owned, &shape_owned, &scale_owned) | |
| }; | |
| let infer_start = Instant::now(); | |
| let infer_result = engine.infer(image_ref, shape_ref, scale_ref, batch); | |
| let infer_us = infer_start.elapsed().as_micros(); | |
| metrics.batches.fetch_add(1, Ordering::Relaxed); | |
| metrics.pages.fetch_add(batch as u64, Ordering::Relaxed); | |
| metrics | |
| .total_infer_us | |
| .fetch_add(infer_us as u64, Ordering::Relaxed); | |
| match infer_result { | |
| Ok((boxes, counts)) => { | |
| for (idx, item) in items.into_iter().enumerate() { | |
| let wait_us = infer_start.duration_since(item.enqueued).as_micros(); | |
| metrics | |
| .total_batch_wait_us | |
| .fetch_add(wait_us as u64, Ordering::Relaxed); | |
| let start = idx * 300 * 7; | |
| let end = start + 300 * 7; | |
| let decoded = decode_page( | |
| &boxes[start..end], | |
| counts[idx].max(0) as usize, | |
| item.return_boxes, | |
| item.score_threshold, | |
| ); | |
| let _ = item.respond_to.send(Ok(PageResult { | |
| boxes: decoded, | |
| batch_size: batch, | |
| queue_wait_us: wait_us, | |
| infer_us, | |
| })); | |
| } | |
| } | |
| Err(e) => { | |
| metrics.errors.fetch_add(batch as u64, Ordering::Relaxed); | |
| let msg = e.to_string(); | |
| for item in items { | |
| let _ = item.respond_to.send(Err(msg.clone())); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| fn run_self_test(engine_path: &str, sample: &PageInput, batch: usize, iters: usize) -> Result<()> { | |
| let mut engine = Engine::load(engine_path)?; | |
| let max_batch = batch.min(engine.max_batch()).max(1); | |
| let mut image = Vec::with_capacity(max_batch * sample.image.len()); | |
| let mut im_shape = Vec::with_capacity(max_batch * 2); | |
| let mut scale = Vec::with_capacity(max_batch * 2); | |
| for _ in 0..max_batch { | |
| push_input(&mut image, &mut im_shape, &mut scale, sample); | |
| } | |
| for _ in 0..iters { | |
| let (_boxes, counts) = engine.infer(&image, &im_shape, &scale, max_batch)?; | |
| if counts.len() != max_batch { | |
| return Err(anyhow!( | |
| "unexpected count length {} for batch {}", | |
| counts.len(), | |
| max_batch | |
| )); | |
| } | |
| } | |
| println!("self_test_ok batch={max_batch} iters={iters}"); | |
| Ok(()) | |
| } | |
| async fn health() -> impl IntoResponse { | |
| Json(serde_json::json!({"status":"ok"})) | |
| } | |
| async fn metrics_handler(State(state): State<AppState>) -> impl IntoResponse { | |
| let requests = state.metrics.requests.load(Ordering::Relaxed); | |
| let batches = state.metrics.batches.load(Ordering::Relaxed); | |
| let pages = state.metrics.pages.load(Ordering::Relaxed); | |
| let errors = state.metrics.errors.load(Ordering::Relaxed); | |
| let wait = state.metrics.total_batch_wait_us.load(Ordering::Relaxed); | |
| let infer = state.metrics.total_infer_us.load(Ordering::Relaxed); | |
| Json(MetricsResponse { | |
| requests, | |
| batches, | |
| pages, | |
| errors, | |
| avg_batch_size: if batches == 0 { | |
| 0.0 | |
| } else { | |
| pages as f64 / batches as f64 | |
| }, | |
| avg_queue_wait_us: if requests == 0 { | |
| 0.0 | |
| } else { | |
| wait as f64 / requests as f64 | |
| }, | |
| avg_infer_us_per_batch: if batches == 0 { | |
| 0.0 | |
| } else { | |
| infer as f64 / batches as f64 | |
| }, | |
| }) | |
| } | |
/// Picks the request's score threshold, falling back to `default`, and clamps
/// the result to `[0, 1]`.
///
/// A NaN request value is treated as absent. `f32::clamp` passes NaN through
/// unchanged, and `score < NaN` is always false, so a NaN threshold would
/// silently disable score filtering.
fn score_threshold_or_default(value: Option<f32>, default: f32) -> f32 {
    value
        .filter(|v| !v.is_nan())
        .unwrap_or(default)
        .clamp(0.0, 1.0)
}
| async fn enqueue_pages( | |
| state: &AppState, | |
| inputs: Vec<PageInput>, | |
| return_boxes: bool, | |
| score_threshold: f32, | |
| ) -> Result<Vec<PageResult>, (StatusCode, Json<serde_json::Value>)> { | |
| let mut receivers = Vec::with_capacity(inputs.len()); | |
| for input in inputs { | |
| let (tx, rx) = oneshot::channel(); | |
| state.metrics.requests.fetch_add(1, Ordering::Relaxed); | |
| if state | |
| .tx | |
| .send(WorkItem { | |
| enqueued: Instant::now(), | |
| input: Some(input), | |
| return_boxes, | |
| score_threshold, | |
| respond_to: tx, | |
| }) | |
| .is_err() | |
| { | |
| state.metrics.errors.fetch_add(1, Ordering::Relaxed); | |
| return Err(( | |
| StatusCode::SERVICE_UNAVAILABLE, | |
| Json(serde_json::json!({"error":"batch worker stopped"})), | |
| )); | |
| } | |
| receivers.push(rx); | |
| } | |
| let mut results = Vec::with_capacity(receivers.len()); | |
| for rx in receivers { | |
| match rx.await { | |
| Ok(Ok(result)) => results.push(result), | |
| Ok(Err(err)) => { | |
| return Err(( | |
| StatusCode::INTERNAL_SERVER_ERROR, | |
| Json(serde_json::json!({"error":err})), | |
| )); | |
| } | |
| Err(_) => { | |
| return Err(( | |
| StatusCode::INTERNAL_SERVER_ERROR, | |
| Json(serde_json::json!({"error":"response channel closed"})), | |
| )); | |
| } | |
| } | |
| } | |
| Ok(results) | |
| } | |
| async fn enqueue_page( | |
| state: &AppState, | |
| input: Option<PageInput>, | |
| return_boxes: bool, | |
| score_threshold: f32, | |
| ) -> Result<PageResult, (StatusCode, Json<serde_json::Value>)> { | |
| let (tx, rx) = oneshot::channel(); | |
| state.metrics.requests.fetch_add(1, Ordering::Relaxed); | |
| if state | |
| .tx | |
| .send(WorkItem { | |
| enqueued: Instant::now(), | |
| input, | |
| return_boxes, | |
| score_threshold, | |
| respond_to: tx, | |
| }) | |
| .is_err() | |
| { | |
| state.metrics.errors.fetch_add(1, Ordering::Relaxed); | |
| return Err(( | |
| StatusCode::SERVICE_UNAVAILABLE, | |
| Json(serde_json::json!({"error":"batch worker stopped"})), | |
| )); | |
| } | |
| match rx.await { | |
| Ok(Ok(result)) => Ok(result), | |
| Ok(Err(err)) => Err(( | |
| StatusCode::INTERNAL_SERVER_ERROR, | |
| Json(serde_json::json!({"error":err})), | |
| )), | |
| Err(_) => Err(( | |
| StatusCode::INTERNAL_SERVER_ERROR, | |
| Json(serde_json::json!({"error":"response channel closed"})), | |
| )), | |
| } | |
| } | |
| async fn infer(State(state): State<AppState>, Json(req): Json<InferRequest>) -> impl IntoResponse { | |
| let score_threshold = | |
| score_threshold_or_default(req.score_threshold, state.default_score_threshold); | |
| match enqueue_page( | |
| &state, | |
| None, | |
| req.return_boxes.unwrap_or(true), | |
| score_threshold, | |
| ) | |
| .await | |
| { | |
| Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(), | |
| Err(err) => err.into_response(), | |
| } | |
| } | |
| async fn layout( | |
| State(state): State<AppState>, | |
| Query(params): Query<LayoutParams>, | |
| mut multipart: Multipart, | |
| ) -> impl IntoResponse { | |
| let score_threshold = | |
| score_threshold_or_default(params.score_threshold, state.default_score_threshold); | |
| let mut preprocess_tasks = Vec::new(); | |
| while let Some(field) = match multipart.next_field().await { | |
| Ok(field) => field, | |
| Err(err) => { | |
| return ( | |
| StatusCode::BAD_REQUEST, | |
| Json(serde_json::json!({"error":err.to_string()})), | |
| ) | |
| .into_response(); | |
| } | |
| } { | |
| let Some(name) = field.name().map(str::to_string) else { | |
| continue; | |
| }; | |
| if name != "file" && name != "files" { | |
| continue; | |
| } | |
| let bytes = match field.bytes().await { | |
| Ok(bytes) => bytes, | |
| Err(err) => { | |
| return ( | |
| StatusCode::BAD_REQUEST, | |
| Json(serde_json::json!({"error":err.to_string()})), | |
| ) | |
| .into_response(); | |
| } | |
| }; | |
| preprocess_tasks.push(task::spawn_blocking(move || { | |
| preprocess_image_bytes(&bytes, false) | |
| })); | |
| } | |
| if preprocess_tasks.is_empty() { | |
| return ( | |
| StatusCode::BAD_REQUEST, | |
| Json(serde_json::json!({"error":"multipart field 'file' or 'files' is required"})), | |
| ) | |
| .into_response(); | |
| } | |
| let mut inputs = Vec::with_capacity(preprocess_tasks.len()); | |
| for handle in preprocess_tasks { | |
| match handle.await { | |
| Ok(Ok(input)) => inputs.push(input), | |
| Ok(Err(err)) => { | |
| return ( | |
| StatusCode::BAD_REQUEST, | |
| Json(serde_json::json!({"error":err.to_string()})), | |
| ) | |
| .into_response(); | |
| } | |
| Err(err) => { | |
| return ( | |
| StatusCode::INTERNAL_SERVER_ERROR, | |
| Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), | |
| ) | |
| .into_response(); | |
| } | |
| } | |
| } | |
| let mut results = Vec::with_capacity(inputs.len()); | |
| for input in inputs { | |
| match enqueue_page(&state, Some(input.clone()), true, score_threshold).await { | |
| Ok(result) => results.push(result), | |
| Err(err) => return err.into_response(), | |
| } | |
| } | |
| Json(serde_json::json!({"pages":results.len(),"results":results})).into_response() | |
| } | |
| async fn layout_tensor( | |
| State(state): State<AppState>, | |
| Query(params): Query<TensorLayoutParams>, | |
| body: Bytes, | |
| ) -> impl IntoResponse { | |
| let original_width = params.original_width.unwrap_or(params.width); | |
| let original_height = params.original_height.unwrap_or(params.height); | |
| let input = match task::spawn_blocking(move || { | |
| preprocess_model_rgb_bytes( | |
| &body, | |
| params.width, | |
| params.height, | |
| original_width, | |
| original_height, | |
| false, | |
| ) | |
| }) | |
| .await | |
| { | |
| Ok(Ok(input)) => input, | |
| Ok(Err(err)) => { | |
| return ( | |
| StatusCode::BAD_REQUEST, | |
| Json(serde_json::json!({"error":err.to_string()})), | |
| ) | |
| .into_response(); | |
| } | |
| Err(err) => { | |
| return ( | |
| StatusCode::INTERNAL_SERVER_ERROR, | |
| Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), | |
| ) | |
| .into_response(); | |
| } | |
| }; | |
| let score_threshold = | |
| score_threshold_or_default(params.score_threshold, state.default_score_threshold); | |
| match enqueue_page(&state, Some(input.clone()), true, score_threshold).await { | |
| Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(), | |
| Err(err) => err.into_response(), | |
| } | |
| } | |
| async fn layout_chw_u8( | |
| State(state): State<AppState>, | |
| Query(params): Query<TensorLayoutParams>, | |
| body: Bytes, | |
| ) -> impl IntoResponse { | |
| let original_width = params.original_width.unwrap_or(params.width); | |
| let original_height = params.original_height.unwrap_or(params.height); | |
| let input = match task::spawn_blocking(move || { | |
| preprocess_model_chw_u8_bytes( | |
| &body, | |
| params.width, | |
| params.height, | |
| original_width, | |
| original_height, | |
| false, | |
| ) | |
| }) | |
| .await | |
| { | |
| Ok(Ok(input)) => input, | |
| Ok(Err(err)) => { | |
| return ( | |
| StatusCode::BAD_REQUEST, | |
| Json(serde_json::json!({"error":err.to_string()})), | |
| ) | |
| .into_response(); | |
| } | |
| Err(err) => { | |
| return ( | |
| StatusCode::INTERNAL_SERVER_ERROR, | |
| Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), | |
| ) | |
| .into_response(); | |
| } | |
| }; | |
| let score_threshold = | |
| score_threshold_or_default(params.score_threshold, state.default_score_threshold); | |
| match enqueue_page(&state, Some(input.clone()), true, score_threshold).await { | |
| Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(), | |
| Err(err) => err.into_response(), | |
| } | |
| } | |
| async fn layout_chw_u8_batch( | |
| State(state): State<AppState>, | |
| Query(params): Query<BatchTensorLayoutParams>, | |
| body: Bytes, | |
| ) -> impl IntoResponse { | |
| let original_width = params.original_width.unwrap_or(params.width); | |
| let original_height = params.original_height.unwrap_or(params.height); | |
| let batch = params.batch; | |
| let inputs = match task::spawn_blocking(move || { | |
| preprocess_model_chw_u8_batch_bytes( | |
| &body, | |
| batch, | |
| params.width, | |
| params.height, | |
| original_width, | |
| original_height, | |
| false, | |
| ) | |
| }) | |
| .await | |
| { | |
| Ok(Ok(inputs)) => inputs, | |
| Ok(Err(err)) => { | |
| return ( | |
| StatusCode::BAD_REQUEST, | |
| Json(serde_json::json!({"error":err.to_string()})), | |
| ) | |
| .into_response(); | |
| } | |
| Err(err) => { | |
| return ( | |
| StatusCode::INTERNAL_SERVER_ERROR, | |
| Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), | |
| ) | |
| .into_response(); | |
| } | |
| }; | |
| let score_threshold = | |
| score_threshold_or_default(params.score_threshold, state.default_score_threshold); | |
| match enqueue_pages(&state, inputs.clone(), true, score_threshold).await { | |
| Ok(results) => { | |
| Json(serde_json::json!({"pages":results.len(),"results":results})).into_response() | |
| } | |
| Err(err) => err.into_response(), | |
| } | |
| } | |
| async fn layout_chw_f32( | |
| State(state): State<AppState>, | |
| Query(params): Query<TensorLayoutParams>, | |
| body: Bytes, | |
| ) -> impl IntoResponse { | |
| let original_width = params.original_width.unwrap_or(params.width); | |
| let original_height = params.original_height.unwrap_or(params.height); | |
| let input = match task::spawn_blocking(move || { | |
| preprocess_model_chw_f32_bytes( | |
| &body, | |
| params.width, | |
| params.height, | |
| original_width, | |
| original_height, | |
| false, | |
| ) | |
| }) | |
| .await | |
| { | |
| Ok(Ok(input)) => input, | |
| Ok(Err(err)) => { | |
| return ( | |
| StatusCode::BAD_REQUEST, | |
| Json(serde_json::json!({"error":err.to_string()})), | |
| ) | |
| .into_response(); | |
| } | |
| Err(err) => { | |
| return ( | |
| StatusCode::INTERNAL_SERVER_ERROR, | |
| Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})), | |
| ) | |
| .into_response(); | |
| } | |
| }; | |
| let score_threshold = | |
| score_threshold_or_default(params.score_threshold, state.default_score_threshold); | |
| match enqueue_page(&state, Some(input.clone()), true, score_threshold).await { | |
| Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(), | |
| Err(err) => err.into_response(), | |
| } | |
| } | |
| async fn main() -> Result<()> { | |
| tracing_subscriber::fmt() | |
| .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) | |
| .init(); | |
| let engine_path = env::var("DOC_LAYOUT_ENGINE").context("DOC_LAYOUT_ENGINE is required")?; | |
| let sample_path = env::var("DOC_LAYOUT_SAMPLE_IMAGE").ok(); | |
| let max_batch = env::var("DOC_LAYOUT_MAX_BATCH") | |
| .ok() | |
| .and_then(|v| v.parse().ok()) | |
| .unwrap_or(8); | |
| let max_delay_us = env::var("DOC_LAYOUT_MAX_DELAY_US") | |
| .ok() | |
| .and_then(|v| v.parse().ok()) | |
| .unwrap_or(1000u64); | |
| let queue_capacity = env::var("DOC_LAYOUT_QUEUE_CAPACITY") | |
| .ok() | |
| .and_then(|v| v.parse().ok()) | |
| .unwrap_or(4096usize); | |
| let default_score_threshold = env::var("DOC_LAYOUT_SCORE_THRESHOLD") | |
| .ok() | |
| .and_then(|v| v.parse::<f32>().ok()) | |
| .unwrap_or(0.5) | |
| .clamp(0.0, 1.0); | |
| let max_upload_mb = env::var("DOC_LAYOUT_MAX_UPLOAD_MB") | |
| .ok() | |
| .and_then(|v| v.parse::<usize>().ok()) | |
| .unwrap_or(512); | |
| let workers = env::var("DOC_LAYOUT_WORKERS") | |
| .ok() | |
| .and_then(|v| v.parse().ok()) | |
| .unwrap_or(1usize) | |
| .max(1); | |
| let port = env::var("DOC_LAYOUT_PORT") | |
| .ok() | |
| .and_then(|v| v.parse().ok()) | |
| .unwrap_or(18082u16); | |
| let sample = load_sample(sample_path.as_deref())?; | |
| if let Some(iters) = env::var("DOC_LAYOUT_SELF_TEST_ITERS") | |
| .ok() | |
| .and_then(|v| v.parse::<usize>().ok()) | |
| { | |
| let batch = env::var("DOC_LAYOUT_SELF_TEST_BATCH") | |
| .ok() | |
| .and_then(|v| v.parse::<usize>().ok()) | |
| .unwrap_or(max_batch); | |
| return run_self_test(&engine_path, &sample, batch, iters); | |
| } | |
| let (tx, rx) = bounded::<WorkItem>(queue_capacity); | |
| let metrics = Arc::new(Metrics::default()); | |
| for worker_id in 0..workers { | |
| let worker_rx = rx.clone(); | |
| let worker_metrics = metrics.clone(); | |
| let worker_engine_path = engine_path.clone(); | |
| let worker_sample = sample.clone(); | |
| thread::spawn(move || { | |
| tracing::info!(worker_id, "starting batch worker"); | |
| batch_worker( | |
| worker_rx, | |
| worker_metrics, | |
| worker_engine_path, | |
| worker_sample, | |
| max_batch, | |
| Duration::from_micros(max_delay_us), | |
| ); | |
| }); | |
| } | |
| let state = AppState { | |
| tx, | |
| metrics, | |
| default_score_threshold, | |
| }; | |
| let app = Router::new() | |
| .route("/health", get(health)) | |
| .route("/metrics", get(metrics_handler)) | |
| .route("/v1/infer", post(infer)) | |
| .route("/v1/layout", post(layout)) | |
| .route("/v1/layout_tensor", post(layout_tensor)) | |
| .route("/v1/layout_chw_f32", post(layout_chw_f32)) | |
| .route("/v1/layout_chw_u8", post(layout_chw_u8)) | |
| .route("/v1/layout_chw_u8_batch", post(layout_chw_u8_batch)) | |
| .layer(DefaultBodyLimit::max(max_upload_mb * 1024 * 1024)) | |
| .with_state(state); | |
| let addr = SocketAddr::from(([0, 0, 0, 0], port)); | |
| tracing::info!( | |
| %addr, | |
| workers, | |
| max_batch, | |
| max_delay_us, | |
| max_upload_mb, | |
| default_score_threshold, | |
| "starting doclayout TensorRT batcher" | |
| ); | |
| let listener = tokio::net::TcpListener::bind(addr).await?; | |
| axum::serve(listener, app).await?; | |
| Ok(()) | |
| } | |