// bndos's picture
// Add pp-doclayout server source with score threshold
// 3c0d3e1 verified
use anyhow::{anyhow, Context, Result};
use axum::{
body::Bytes,
extract::{DefaultBodyLimit, Multipart, Query, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use crossbeam_channel::{bounded, Receiver, Sender};
use serde::{Deserialize, Serialize};
use std::{
env,
ffi::CString,
net::SocketAddr,
os::raw::{c_char, c_float, c_int},
path::Path,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
thread,
time::{Duration, Instant},
};
use tokio::{sync::oneshot, task};
/// Opaque stand-in for the native TensorRT context type.
///
/// Rust code in this file only handles it behind raw pointers obtained from
/// `trt_create`; it is never constructed or inspected on the Rust side.
#[repr(C)]
struct TrtContextOpaque {
_private: [u8; 0],
}
// Native inference entry points. Their implementation is not visible in this
// file; the notes below describe only how the Rust side uses them.
extern "C" {
// Creates a context from an engine file path; `Engine::load` treats a null
// return value as failure.
fn trt_create(engine_path: *const c_char) -> *mut TrtContextOpaque;
// Releases a context; called from `Engine`'s `Drop` implementation.
fn trt_destroy(ctx: *mut TrtContextOpaque);
// Reports a batch limit; `Engine::max_batch` clamps the value to at least 1.
fn trt_max_batch(ctx: *mut TrtContextOpaque) -> c_int;
// Runs inference for `batch` pages. `Engine::infer` passes output buffers of
// `batch * 300 * 7` floats and `batch` counts, and treats any non-zero return
// code as an error.
fn trt_infer(
ctx: *mut TrtContextOpaque,
image: *const c_float,
im_shape: *const c_float,
scale_factor: *const c_float,
batch: c_int,
boxes_out: *mut c_float,
counts_out: *mut i32,
) -> c_int;
}
/// Class names indexed by the class id found in each detection row.
///
/// `decode_page` looks labels up here and drops rows whose class id falls
/// outside this table.
const LABELS: [&str; 25] = [
"abstract",
"algorithm",
"aside_text",
"chart",
"content",
"display_formula",
"doc_title",
"figure_title",
"footer",
"footer_image",
"footnote",
"formula_number",
"header",
"header_image",
"image",
"inline_formula",
"number",
"paragraph_title",
"reference",
"reference_content",
"seal",
"table",
"text",
"vertical_text",
"vision_footnote",
];
/// Owning wrapper around the raw native context pointer; the context is
/// released by the `Drop` implementation.
struct Engine(*mut TrtContextOpaque);
// SAFETY: NOTE(review) — this asserts that the native context may be moved to
// and used from another thread. Nothing in this file proves that; confirm
// against the native library.
unsafe impl Send for Engine {}
impl Engine {
fn load(path: &str) -> Result<Self> {
let c_path = CString::new(path)?;
let ptr = unsafe { trt_create(c_path.as_ptr()) };
if ptr.is_null() {
return Err(anyhow!("failed to load TensorRT engine: {path}"));
}
Ok(Self(ptr))
}
fn max_batch(&self) -> usize {
unsafe { trt_max_batch(self.0).max(1) as usize }
}
fn infer(
&mut self,
image: &[f32],
im_shape: &[f32],
scale_factor: &[f32],
batch: usize,
) -> Result<(Vec<f32>, Vec<i32>)> {
let mut boxes = vec![0.0f32; batch * 300 * 7];
let mut counts = vec![0i32; batch];
let code = unsafe {
trt_infer(
self.0,
image.as_ptr(),
im_shape.as_ptr(),
scale_factor.as_ptr(),
batch as c_int,
boxes.as_mut_ptr(),
counts.as_mut_ptr(),
)
};
if code != 0 {
return Err(anyhow!("trt_infer failed with code {code}"));
}
Ok((boxes, counts))
}
}
/// Releases the native context when the wrapper goes out of scope.
impl Drop for Engine {
fn drop(&mut self) {
// SAFETY: `Engine::load` only builds an `Engine` from a non-null pointer
// returned by `trt_create`, and `drop` runs at most once per value.
unsafe { trt_destroy(self.0) }
}
}
/// One preprocessed page ready to be appended to an inference batch.
#[derive(Clone)]
struct PageInput {
/// Flattened `f32` tensor for the page, shared via `Arc` so clones are cheap.
image: Arc<Vec<f32>>,
/// `[height, width]` of the tensor handed to the model.
im_shape: [f32; 2],
/// `[height ratio, width ratio]` of model size over original size.
scale_factor: [f32; 2],
}
/// Process-wide counters updated by the request handlers and batch workers.
#[derive(Default)]
struct Metrics {
/// Pages submitted to the queue (incremented once per enqueue attempt).
requests: AtomicU64,
/// Inference calls issued by the workers.
batches: AtomicU64,
/// Pages included in those inference calls.
pages: AtomicU64,
/// Pages that failed, either at enqueue time or because inference failed.
errors: AtomicU64,
/// Sum of per-page queue wait times in microseconds (successful batches only).
total_batch_wait_us: AtomicU64,
/// Sum of inference call durations in microseconds.
total_infer_us: AtomicU64,
}
/// Shared state handed to every HTTP handler.
#[derive(Clone)]
struct AppState {
/// Producer side of the work queue consumed by the batch workers.
tx: Sender<WorkItem>,
/// Counters reported by the `/metrics` endpoint.
metrics: Arc<Metrics>,
/// Threshold applied when a request does not supply `score_threshold`.
default_score_threshold: f32,
}
/// One page queued for inference together with its reply channel.
struct WorkItem {
/// When the item was created; used to compute queue wait time.
enqueued: Instant,
/// Page to run; `None` makes the worker substitute its sample page.
input: Option<PageInput>,
/// When false, the worker skips decoding and returns no boxes.
return_boxes: bool,
/// Detections scoring below this value are dropped during decoding.
score_threshold: f32,
/// Channel on which the worker sends the page result or an error message.
respond_to: oneshot::Sender<Result<PageResult, String>>,
}
/// One decoded detection as returned in the JSON response.
#[derive(Debug, Serialize, Clone)]
struct BoxResult {
/// Entry of `LABELS` selected by `class_id`.
label: String,
/// Class index taken from the first value of the detection row.
class_id: i32,
/// Confidence taken from the second value of the detection row.
score: f32,
/// Row values 2..=5 rounded to integers — presumably x1, y1, x2, y2; TODO confirm.
bbox: [i32; 4],
/// Row value 6 truncated to an integer; results are sorted by it.
order: i32,
/// Constant tag identifying the producer of the detection.
source: String,
}
/// Per-page result sent back from a batch worker.
#[derive(Debug, Serialize, Clone)]
struct PageResult {
/// Decoded detections, sorted by `order`.
boxes: Vec<BoxResult>,
/// Number of pages in the batch this page was processed with.
batch_size: usize,
/// Microseconds between enqueue and the start of inference.
queue_wait_us: u128,
/// Duration of the inference call for the whole batch, in microseconds.
infer_us: u128,
}
/// JSON body of the `/metrics` endpoint.
#[derive(Debug, Serialize)]
struct MetricsResponse {
requests: u64,
batches: u64,
pages: u64,
errors: u64,
/// `pages / batches`, or 0 when no batch has run.
avg_batch_size: f64,
/// Total queue wait divided by `requests`, or 0 when there were none.
avg_queue_wait_us: f64,
/// Total inference time divided by `batches`, or 0 when no batch has run.
avg_infer_us_per_batch: f64,
}
/// JSON body accepted by `/v1/infer`.
#[derive(Debug, Deserialize)]
struct InferRequest {
/// Whether to decode boxes; the handler defaults this to `true`.
return_boxes: Option<bool>,
/// Overrides the server's default score threshold when present.
score_threshold: Option<f32>,
}
/// Query parameters accepted by `/v1/layout`.
#[derive(Debug, Deserialize)]
struct LayoutParams {
/// Overrides the server's default score threshold when present.
score_threshold: Option<f32>,
}
/// Query parameters for the single-page raw tensor endpoints.
#[derive(Debug, Deserialize)]
struct TensorLayoutParams {
/// Width of the tensor in the request body.
width: u32,
/// Height of the tensor in the request body.
height: u32,
/// Width used for the scale factor; handlers fall back to `width`.
original_width: Option<u32>,
/// Height used for the scale factor; handlers fall back to `height`.
original_height: Option<u32>,
/// Overrides the server's default score threshold when present.
score_threshold: Option<f32>,
}
/// Query parameters for the batched raw tensor endpoint.
#[derive(Debug, Deserialize)]
struct BatchTensorLayoutParams {
/// Number of pages concatenated in the request body.
batch: usize,
/// Width of each page tensor.
width: u32,
/// Height of each page tensor.
height: u32,
/// Width used for the scale factor; the handler falls back to `width`.
original_width: Option<u32>,
/// Height used for the scale factor; the handler falls back to `height`.
original_height: Option<u32>,
/// Overrides the server's default score threshold when present.
score_threshold: Option<f32>,
}
/// Resizes `image` to 800x800 and converts it to a planar (all red, then all
/// green, then all blue) `f32` tensor with values scaled to `[0, 1]`.
///
/// `width` and `height` are the original dimensions used to compute the
/// scale factor; zero is treated as one so the ratio stays finite, matching
/// the other preprocessors in this file. `_keep_original` is unused.
fn preprocess_rgb(
    image: image::RgbImage,
    width: u32,
    height: u32,
    _keep_original: bool,
) -> PageInput {
    const TARGET: u32 = 800;
    let side = TARGET as usize;
    let plane = side * side;
    let resized = image::imageops::resize(
        &image,
        TARGET,
        TARGET,
        image::imageops::FilterType::Triangle,
    );
    let mut chw = vec![0.0f32; 3 * plane];
    {
        // Split the buffer into its three colour planes so each pixel is
        // written without recomputing plane offsets.
        let (red, rest) = chw.split_at_mut(plane);
        let (green, blue) = rest.split_at_mut(plane);
        // `pixels()` walks the buffer row by row, so the running index equals
        // `y * TARGET + x`.
        for (idx, px) in resized.pixels().enumerate() {
            let [r, g, b] = px.0;
            red[idx] = f32::from(r) / 255.0;
            green[idx] = f32::from(g) / 255.0;
            blue[idx] = f32::from(b) / 255.0;
        }
    }
    PageInput {
        image: Arc::new(chw),
        im_shape: [TARGET as f32, TARGET as f32],
        scale_factor: [
            TARGET as f32 / height.max(1) as f32,
            TARGET as f32 / width.max(1) as f32,
        ],
    }
}
/// Builds the sample page used when a work item carries no input.
///
/// With `Some(path)` the image at that path is opened and converted to RGB;
/// with `None` a plain white 800x800 image is generated. Either way the
/// image is passed through `preprocess_rgb` with its own dimensions.
fn load_sample(path: Option<&str>) -> Result<PageInput> {
    let rgb = match path {
        Some(location) => image::open(Path::new(location))
            .with_context(|| format!("open sample image {location}"))?
            .to_rgb8(),
        None => image::RgbImage::from_pixel(800, 800, image::Rgb([255, 255, 255])),
    };
    let (w, h) = rgb.dimensions();
    Ok(preprocess_rgb(rgb, w, h, false))
}
/// Decodes an encoded image from `bytes`, converts it to RGB and preprocesses
/// it using the decoded dimensions. `_keep_original` is unused.
fn preprocess_image_bytes(bytes: &[u8], _keep_original: bool) -> Result<PageInput> {
    let decoded = image::load_from_memory(bytes)?;
    let rgb = decoded.to_rgb8();
    let (w, h) = (rgb.width(), rgb.height());
    Ok(preprocess_rgb(rgb, w, h, false))
}
/// Converts an interleaved RGB `u8` body that is already at model resolution
/// into a planar `f32` tensor with values scaled to `[0, 1]`.
///
/// `original_width` / `original_height` are used only for the scale factor;
/// zero is treated as one. `_keep_original` is unused.
///
/// # Errors
/// Fails when the dimensions are not 800x800 or the body length is not
/// `width * height * 3`.
fn preprocess_model_rgb_bytes(
    bytes: &[u8],
    width: u32,
    height: u32,
    original_width: u32,
    original_height: u32,
    _keep_original: bool,
) -> Result<PageInput> {
    // Validate the dimensions first: they come from the query string, and the
    // size arithmetic below must not run on unchecked values.
    if width != 800 || height != 800 {
        return Err(anyhow!(
            "model RGB input must be 800x800, got {width}x{height}"
        ));
    }
    let plane = width as usize * height as usize;
    let expected = plane * 3;
    if bytes.len() != expected {
        return Err(anyhow!(
            "RGB body has {} bytes, expected {} for {}x{}x3",
            bytes.len(),
            expected,
            width,
            height
        ));
    }
    let mut chw = vec![0.0f32; expected];
    {
        let (red, rest) = chw.split_at_mut(plane);
        let (green, blue) = rest.split_at_mut(plane);
        // De-interleave: every 3-byte chunk is one pixel's R, G, B.
        for (idx, px) in bytes.chunks_exact(3).enumerate() {
            red[idx] = f32::from(px[0]) / 255.0;
            green[idx] = f32::from(px[1]) / 255.0;
            blue[idx] = f32::from(px[2]) / 255.0;
        }
    }
    Ok(PageInput {
        image: Arc::new(chw),
        im_shape: [height as f32, width as f32],
        scale_factor: [
            height as f32 / original_height.max(1) as f32,
            width as f32 / original_width.max(1) as f32,
        ],
    })
}
/// Converts a `u8` body of `3 * height * width` values at model resolution
/// into an `f32` tensor by dividing every value by 255, keeping the order.
///
/// `original_width` / `original_height` feed only the scale factor (zero is
/// treated as one). `_keep_original` is unused.
///
/// # Errors
/// Fails when the dimensions are not 800x800 or the body length is wrong.
fn preprocess_model_chw_u8_bytes(
    bytes: &[u8],
    width: u32,
    height: u32,
    original_width: u32,
    original_height: u32,
    _keep_original: bool,
) -> Result<PageInput> {
    if !(width == 800 && height == 800) {
        return Err(anyhow!(
            "model CHW u8 input must be 800x800, got {width}x{height}"
        ));
    }
    let wanted = width as usize * height as usize * 3;
    let got = bytes.len();
    if got != wanted {
        return Err(anyhow!(
            "CHW u8 body has {} bytes, expected {} for 3x{}x{}",
            got,
            wanted,
            height,
            width
        ));
    }
    let tensor: Vec<f32> = bytes.iter().map(|&value| value as f32 / 255.0).collect();
    let scale_h = height as f32 / original_height.max(1) as f32;
    let scale_w = width as f32 / original_width.max(1) as f32;
    Ok(PageInput {
        image: Arc::new(tensor),
        im_shape: [height as f32, width as f32],
        scale_factor: [scale_h, scale_w],
    })
}
/// Splits a body holding `batch` concatenated `u8` pages into individual
/// `PageInput`s, converting each with `preprocess_model_chw_u8_bytes`.
///
/// All pages share the same `original_width` / `original_height`.
/// `_keep_original` is unused.
///
/// # Errors
/// Fails when `batch` is zero, when the requested size overflows `usize`,
/// when the body length does not equal `batch * width * height * 3`, or when
/// any page is rejected by the per-page conversion.
fn preprocess_model_chw_u8_batch_bytes(
    bytes: &[u8],
    batch: usize,
    width: u32,
    height: u32,
    original_width: u32,
    original_height: u32,
    _keep_original: bool,
) -> Result<Vec<PageInput>> {
    if batch == 0 {
        return Err(anyhow!("batch must be > 0"));
    }
    // `batch`, `width` and `height` come straight from the query string, so
    // the size arithmetic is checked rather than allowed to wrap or panic.
    let per_page = (width as usize)
        .checked_mul(height as usize)
        .and_then(|plane| plane.checked_mul(3));
    let total = per_page.and_then(|page| page.checked_mul(batch));
    let (Some(per_page), Some(expected)) = (per_page, total) else {
        return Err(anyhow!(
            "batch={batch} of 3x{height}x{width} pages exceeds the addressable size"
        ));
    };
    if bytes.len() != expected {
        return Err(anyhow!(
            "batched CHW u8 body has {} bytes, expected {} for batch={} 3x{}x{}",
            bytes.len(),
            expected,
            batch,
            height,
            width
        ));
    }
    // The length check above guarantees every page range is in bounds.
    (0..batch)
        .map(|idx| {
            let start = idx * per_page;
            preprocess_model_chw_u8_bytes(
                &bytes[start..start + per_page],
                width,
                height,
                original_width,
                original_height,
                false,
            )
        })
        .collect()
}
/// Reinterprets a body of raw native-endian `f32` values at model resolution
/// as the page tensor. Values are copied as-is, with no scaling.
///
/// `original_width` / `original_height` feed only the scale factor (zero is
/// treated as one). `_keep_original` is unused.
///
/// # Errors
/// Fails when the dimensions are not 800x800 or the body length is not
/// `3 * height * width * 4` bytes.
fn preprocess_model_chw_f32_bytes(
    bytes: &[u8],
    width: u32,
    height: u32,
    original_width: u32,
    original_height: u32,
    _keep_original: bool,
) -> Result<PageInput> {
    const F32_SIZE: usize = std::mem::size_of::<f32>();
    if width != 800 || height != 800 {
        return Err(anyhow!(
            "model CHW f32 input must be 800x800, got {width}x{height}"
        ));
    }
    let floats = width as usize * height as usize * 3;
    let expected = floats * F32_SIZE;
    if bytes.len() != expected {
        return Err(anyhow!(
            "CHW f32 body has {} bytes, expected {} for 3x{}x{}",
            bytes.len(),
            expected,
            height,
            width
        ));
    }
    // Safe equivalent of a byte-wise memcpy into an `f32` buffer: each 4-byte
    // chunk is decoded using the host's native byte order.
    let chw: Vec<f32> = bytes
        .chunks_exact(F32_SIZE)
        .map(|raw| f32::from_ne_bytes([raw[0], raw[1], raw[2], raw[3]]))
        .collect();
    Ok(PageInput {
        image: Arc::new(chw),
        im_shape: [height as f32, width as f32],
        scale_factor: [
            height as f32 / original_height.max(1) as f32,
            width as f32 / original_width.max(1) as f32,
        ],
    })
}
/// Decodes one page's raw detection rows into `BoxResult`s sorted by `order`.
///
/// `boxes` holds rows of seven floats: class id, score, four box coordinates
/// and an ordering value. At most `min(count, 300)` rows are read, and never
/// more than `boxes` actually contains. Rows are dropped when the score is
/// NaN or below `score_threshold`, or when the class id is NaN or outside
/// `LABELS`. Returns an empty vector when `return_boxes` is false.
fn decode_page(
    boxes: &[f32],
    count: usize,
    return_boxes: bool,
    score_threshold: f32,
) -> Vec<BoxResult> {
    const ROW_LEN: usize = 7;
    const MAX_ROWS: usize = 300;

    if !return_boxes {
        return Vec::new();
    }
    let mut out: Vec<BoxResult> = boxes
        // `chunks_exact` cannot run past the end of a short slice.
        .chunks_exact(ROW_LEN)
        .take(count.min(MAX_ROWS))
        .filter_map(|row| {
            let score = row[1];
            // A NaN score compares false against any threshold, so reject it
            // explicitly instead of letting it slip through the filter.
            if score.is_nan() || score < score_threshold {
                return None;
            }
            // A NaN class id would otherwise be cast to 0 and mislabelled.
            if row[0].is_nan() {
                return None;
            }
            let class_id = row[0] as i32;
            let label = usize::try_from(class_id)
                .ok()
                .and_then(|index| LABELS.get(index))?;
            Some(BoxResult {
                label: (*label).to_string(),
                class_id,
                score,
                bbox: [
                    row[2].round() as i32,
                    row[3].round() as i32,
                    row[4].round() as i32,
                    row[5].round() as i32,
                ],
                order: row[6] as i32,
                source: "pp_doclayout_v3".to_string(),
            })
        })
        .collect();
    out.sort_by_key(|b| b.order);
    out
}
fn push_input(
dst_image: &mut Vec<f32>,
dst_shape: &mut Vec<f32>,
dst_scale: &mut Vec<f32>,
input: &PageInput,
) {
dst_image.extend_from_slice(&input.image);
dst_shape.extend_from_slice(&input.im_shape);
dst_scale.extend_from_slice(&input.scale_factor);
}
/// Worker loop: loads its own engine, then repeatedly collects queued work
/// items into a batch, runs inference and replies to each item.
///
/// A batch starts with the first item received and grows until it reaches
/// `min(max_batch_cfg, engine limit)` items or `max_delay` has elapsed. Items
/// without an input use `sample`. The loop ends when the channel is
/// disconnected.
///
/// # Panics
/// Panics if the engine at `engine_path` cannot be loaded.
fn batch_worker(
    rx: Receiver<WorkItem>,
    metrics: Arc<Metrics>,
    engine_path: String,
    sample: PageInput,
    max_batch_cfg: usize,
    max_delay: Duration,
) {
    // Floats reserved per page in the output buffer returned by `Engine::infer`.
    const OUT_FLOATS_PER_PAGE: usize = 300 * 7;

    let mut engine = Engine::load(&engine_path).expect("load TensorRT engine");
    let max_batch = max_batch_cfg.min(engine.max_batch()).max(1);
    let page_len = sample.image.len();

    // A batch of `b` sample pages is a prefix of a batch of `max_batch` sample
    // pages, so one buffer of the maximum size serves every batch size.
    let mut cached_image = Vec::with_capacity(max_batch * page_len);
    let mut cached_shape = Vec::with_capacity(max_batch * 2);
    let mut cached_scale = Vec::with_capacity(max_batch * 2);
    for _ in 0..max_batch {
        push_input(
            &mut cached_image,
            &mut cached_shape,
            &mut cached_scale,
            &sample,
        );
    }

    // Scratch buffers for batches with real inputs, reused across iterations.
    let mut image_owned: Vec<f32> = Vec::new();
    let mut shape_owned: Vec<f32> = Vec::new();
    let mut scale_owned: Vec<f32> = Vec::new();

    tracing::info!(
        engine_path,
        max_batch,
        delay_us = max_delay.as_micros(),
        "rust batcher ready"
    );

    // `recv` fails only once every sender is gone, which ends the worker.
    while let Ok(first) = rx.recv() {
        let mut items = vec![first];
        let deadline = Instant::now() + max_delay;
        while items.len() < max_batch {
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                break;
            }
            match rx.recv_timeout(remaining) {
                Ok(item) => items.push(item),
                // Timeout or disconnect: run what has been collected so far.
                Err(_) => break,
            }
        }

        let batch = items.len();
        let all_sample = items.iter().all(|item| item.input.is_none());
        let (image_ref, shape_ref, scale_ref): (&[f32], &[f32], &[f32]) = if all_sample {
            (
                &cached_image[..batch * page_len],
                &cached_shape[..batch * 2],
                &cached_scale[..batch * 2],
            )
        } else {
            image_owned.clear();
            shape_owned.clear();
            scale_owned.clear();
            image_owned.reserve(batch * page_len);
            shape_owned.reserve(batch * 2);
            scale_owned.reserve(batch * 2);
            for item in &items {
                let input = item.input.as_ref().unwrap_or(&sample);
                push_input(&mut image_owned, &mut shape_owned, &mut scale_owned, input);
            }
            (
                image_owned.as_slice(),
                shape_owned.as_slice(),
                scale_owned.as_slice(),
            )
        };

        let infer_start = Instant::now();
        let infer_result = engine.infer(image_ref, shape_ref, scale_ref, batch);
        let infer_us = infer_start.elapsed().as_micros();
        metrics.batches.fetch_add(1, Ordering::Relaxed);
        metrics.pages.fetch_add(batch as u64, Ordering::Relaxed);
        metrics
            .total_infer_us
            .fetch_add(infer_us as u64, Ordering::Relaxed);

        match infer_result {
            Ok((boxes, counts)) => {
                for (idx, item) in items.into_iter().enumerate() {
                    let wait_us = infer_start.duration_since(item.enqueued).as_micros();
                    metrics
                        .total_batch_wait_us
                        .fetch_add(wait_us as u64, Ordering::Relaxed);
                    let start = idx * OUT_FLOATS_PER_PAGE;
                    let decoded = decode_page(
                        &boxes[start..start + OUT_FLOATS_PER_PAGE],
                        counts[idx].max(0) as usize,
                        item.return_boxes,
                        item.score_threshold,
                    );
                    // The requester may have gone away; a failed send is fine.
                    let _ = item.respond_to.send(Ok(PageResult {
                        boxes: decoded,
                        batch_size: batch,
                        queue_wait_us: wait_us,
                        infer_us,
                    }));
                }
            }
            Err(e) => {
                metrics.errors.fetch_add(batch as u64, Ordering::Relaxed);
                let msg = e.to_string();
                for item in items {
                    let _ = item.respond_to.send(Err(msg.clone()));
                }
            }
        }
    }
}
/// Loads the engine and runs `iters` inference calls on a batch made of
/// repeated copies of `sample`, printing a confirmation line on success.
///
/// The batch size is `batch` capped by the engine limit and never below 1.
///
/// # Errors
/// Fails when the engine cannot be loaded, when inference fails, or when the
/// number of returned counts differs from the batch size.
fn run_self_test(engine_path: &str, sample: &PageInput, batch: usize, iters: usize) -> Result<()> {
    let mut engine = Engine::load(engine_path)?;
    let pages = batch.min(engine.max_batch()).max(1);

    let mut tensor = Vec::with_capacity(pages * sample.image.len());
    let mut shapes = Vec::with_capacity(pages * 2);
    let mut scales = Vec::with_capacity(pages * 2);
    (0..pages).for_each(|_| push_input(&mut tensor, &mut shapes, &mut scales, sample));

    for _ in 0..iters {
        let (_, counts) = engine.infer(&tensor, &shapes, &scales, pages)?;
        if counts.len() == pages {
            continue;
        }
        return Err(anyhow!(
            "unexpected count length {} for batch {}",
            counts.len(),
            pages
        ));
    }
    println!("self_test_ok batch={pages} iters={iters}");
    Ok(())
}
/// Liveness endpoint: always answers with a fixed JSON status object.
async fn health() -> impl IntoResponse {
    let body = serde_json::json!({"status":"ok"});
    Json(body)
}
/// Reports the current counters together with three derived averages; each
/// average is 0 when its divisor is 0.
async fn metrics_handler(State(state): State<AppState>) -> impl IntoResponse {
    // Division that yields 0.0 instead of NaN/inf for an empty denominator.
    let mean = |total: u64, samples: u64| -> f64 {
        match samples {
            0 => 0.0,
            n => total as f64 / n as f64,
        }
    };
    let counters = &state.metrics;
    let requests = counters.requests.load(Ordering::Relaxed);
    let batches = counters.batches.load(Ordering::Relaxed);
    let pages = counters.pages.load(Ordering::Relaxed);
    let errors = counters.errors.load(Ordering::Relaxed);
    let wait_total = counters.total_batch_wait_us.load(Ordering::Relaxed);
    let infer_total = counters.total_infer_us.load(Ordering::Relaxed);
    Json(MetricsResponse {
        requests,
        batches,
        pages,
        errors,
        avg_batch_size: mean(pages, batches),
        avg_queue_wait_us: mean(wait_total, requests),
        avg_infer_us_per_batch: mean(infer_total, batches),
    })
}
/// Picks the request's score threshold, falling back to `default`, and clamps
/// the result to `[0, 1]`.
///
/// A NaN request value is treated as absent: NaN survives `clamp`, and every
/// `score < NaN` comparison is false, so it would disable filtering entirely.
/// `default` is expected to be a valid number; a NaN default is returned
/// unchanged.
fn score_threshold_or_default(value: Option<f32>, default: f32) -> f32 {
    value
        .filter(|candidate| !candidate.is_nan())
        .unwrap_or(default)
        .clamp(0.0, 1.0)
}
async fn enqueue_pages(
state: &AppState,
inputs: Vec<PageInput>,
return_boxes: bool,
score_threshold: f32,
) -> Result<Vec<PageResult>, (StatusCode, Json<serde_json::Value>)> {
let mut receivers = Vec::with_capacity(inputs.len());
for input in inputs {
let (tx, rx) = oneshot::channel();
state.metrics.requests.fetch_add(1, Ordering::Relaxed);
if state
.tx
.send(WorkItem {
enqueued: Instant::now(),
input: Some(input),
return_boxes,
score_threshold,
respond_to: tx,
})
.is_err()
{
state.metrics.errors.fetch_add(1, Ordering::Relaxed);
return Err((
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({"error":"batch worker stopped"})),
));
}
receivers.push(rx);
}
let mut results = Vec::with_capacity(receivers.len());
for rx in receivers {
match rx.await {
Ok(Ok(result)) => results.push(result),
Ok(Err(err)) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error":err})),
));
}
Err(_) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error":"response channel closed"})),
));
}
}
}
Ok(results)
}
async fn enqueue_page(
state: &AppState,
input: Option<PageInput>,
return_boxes: bool,
score_threshold: f32,
) -> Result<PageResult, (StatusCode, Json<serde_json::Value>)> {
let (tx, rx) = oneshot::channel();
state.metrics.requests.fetch_add(1, Ordering::Relaxed);
if state
.tx
.send(WorkItem {
enqueued: Instant::now(),
input,
return_boxes,
score_threshold,
respond_to: tx,
})
.is_err()
{
state.metrics.errors.fetch_add(1, Ordering::Relaxed);
return Err((
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({"error":"batch worker stopped"})),
));
}
match rx.await {
Ok(Ok(result)) => Ok(result),
Ok(Err(err)) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error":err})),
)),
Err(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error":"response channel closed"})),
)),
}
}
async fn infer(State(state): State<AppState>, Json(req): Json<InferRequest>) -> impl IntoResponse {
let score_threshold =
score_threshold_or_default(req.score_threshold, state.default_score_threshold);
match enqueue_page(
&state,
None,
req.return_boxes.unwrap_or(true),
score_threshold,
)
.await
{
Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(),
Err(err) => err.into_response(),
}
}
async fn layout(
State(state): State<AppState>,
Query(params): Query<LayoutParams>,
mut multipart: Multipart,
) -> impl IntoResponse {
let score_threshold =
score_threshold_or_default(params.score_threshold, state.default_score_threshold);
let mut preprocess_tasks = Vec::new();
while let Some(field) = match multipart.next_field().await {
Ok(field) => field,
Err(err) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error":err.to_string()})),
)
.into_response();
}
} {
let Some(name) = field.name().map(str::to_string) else {
continue;
};
if name != "file" && name != "files" {
continue;
}
let bytes = match field.bytes().await {
Ok(bytes) => bytes,
Err(err) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error":err.to_string()})),
)
.into_response();
}
};
preprocess_tasks.push(task::spawn_blocking(move || {
preprocess_image_bytes(&bytes, false)
}));
}
if preprocess_tasks.is_empty() {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error":"multipart field 'file' or 'files' is required"})),
)
.into_response();
}
let mut inputs = Vec::with_capacity(preprocess_tasks.len());
for handle in preprocess_tasks {
match handle.await {
Ok(Ok(input)) => inputs.push(input),
Ok(Err(err)) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error":err.to_string()})),
)
.into_response();
}
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})),
)
.into_response();
}
}
}
let mut results = Vec::with_capacity(inputs.len());
for input in inputs {
match enqueue_page(&state, Some(input.clone()), true, score_threshold).await {
Ok(result) => results.push(result),
Err(err) => return err.into_response(),
}
}
Json(serde_json::json!({"pages":results.len(),"results":results})).into_response()
}
/// `/v1/layout_tensor`: accepts an interleaved RGB `u8` body at model
/// resolution and returns a one-page result.
///
/// `original_width` / `original_height` default to `width` / `height`.
/// Responds with 400 when the body is rejected by preprocessing and 500 when
/// the preprocessing task fails to complete.
async fn layout_tensor(
    State(state): State<AppState>,
    Query(params): Query<TensorLayoutParams>,
    body: Bytes,
) -> impl IntoResponse {
    let original_width = params.original_width.unwrap_or(params.width);
    let original_height = params.original_height.unwrap_or(params.height);
    let (width, height) = (params.width, params.height);
    let preprocessed = task::spawn_blocking(move || {
        preprocess_model_rgb_bytes(&body, width, height, original_width, original_height, false)
    })
    .await;
    let input = match preprocessed {
        Ok(Ok(input)) => input,
        Ok(Err(err)) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({"error":err.to_string()})),
            )
                .into_response();
        }
        Err(err) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})),
            )
                .into_response();
        }
    };
    let score_threshold =
        score_threshold_or_default(params.score_threshold, state.default_score_threshold);
    // `input` is not used afterwards, so it is moved rather than cloned.
    match enqueue_page(&state, Some(input), true, score_threshold).await {
        Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(),
        Err(err) => err.into_response(),
    }
}
/// `/v1/layout_chw_u8`: accepts a `u8` tensor body at model resolution and
/// returns a one-page result.
///
/// `original_width` / `original_height` default to `width` / `height`.
/// Responds with 400 when the body is rejected by preprocessing and 500 when
/// the preprocessing task fails to complete.
async fn layout_chw_u8(
    State(state): State<AppState>,
    Query(params): Query<TensorLayoutParams>,
    body: Bytes,
) -> impl IntoResponse {
    let original_width = params.original_width.unwrap_or(params.width);
    let original_height = params.original_height.unwrap_or(params.height);
    let (width, height) = (params.width, params.height);
    let preprocessed = task::spawn_blocking(move || {
        preprocess_model_chw_u8_bytes(
            &body,
            width,
            height,
            original_width,
            original_height,
            false,
        )
    })
    .await;
    let input = match preprocessed {
        Ok(Ok(input)) => input,
        Ok(Err(err)) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({"error":err.to_string()})),
            )
                .into_response();
        }
        Err(err) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})),
            )
                .into_response();
        }
    };
    let score_threshold =
        score_threshold_or_default(params.score_threshold, state.default_score_threshold);
    // `input` is not used afterwards, so it is moved rather than cloned.
    match enqueue_page(&state, Some(input), true, score_threshold).await {
        Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(),
        Err(err) => err.into_response(),
    }
}
/// `/v1/layout_chw_u8_batch`: accepts `batch` concatenated `u8` page tensors
/// and returns one result per page, queuing all pages together.
///
/// `original_width` / `original_height` default to `width` / `height` and
/// apply to every page. Responds with 400 when the body is rejected by
/// preprocessing and 500 when the preprocessing task fails to complete.
async fn layout_chw_u8_batch(
    State(state): State<AppState>,
    Query(params): Query<BatchTensorLayoutParams>,
    body: Bytes,
) -> impl IntoResponse {
    let original_width = params.original_width.unwrap_or(params.width);
    let original_height = params.original_height.unwrap_or(params.height);
    let (batch, width, height) = (params.batch, params.width, params.height);
    let preprocessed = task::spawn_blocking(move || {
        preprocess_model_chw_u8_batch_bytes(
            &body,
            batch,
            width,
            height,
            original_width,
            original_height,
            false,
        )
    })
    .await;
    let inputs = match preprocessed {
        Ok(Ok(inputs)) => inputs,
        Ok(Err(err)) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({"error":err.to_string()})),
            )
                .into_response();
        }
        Err(err) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})),
            )
                .into_response();
        }
    };
    let score_threshold =
        score_threshold_or_default(params.score_threshold, state.default_score_threshold);
    // `inputs` is not used afterwards, so the vector is moved rather than cloned.
    match enqueue_pages(&state, inputs, true, score_threshold).await {
        Ok(results) => {
            Json(serde_json::json!({"pages":results.len(),"results":results})).into_response()
        }
        Err(err) => err.into_response(),
    }
}
/// `/v1/layout_chw_f32`: accepts a raw `f32` tensor body at model resolution
/// and returns a one-page result.
///
/// `original_width` / `original_height` default to `width` / `height`.
/// Responds with 400 when the body is rejected by preprocessing and 500 when
/// the preprocessing task fails to complete.
async fn layout_chw_f32(
    State(state): State<AppState>,
    Query(params): Query<TensorLayoutParams>,
    body: Bytes,
) -> impl IntoResponse {
    let original_width = params.original_width.unwrap_or(params.width);
    let original_height = params.original_height.unwrap_or(params.height);
    let (width, height) = (params.width, params.height);
    let preprocessed = task::spawn_blocking(move || {
        preprocess_model_chw_f32_bytes(
            &body,
            width,
            height,
            original_width,
            original_height,
            false,
        )
    })
    .await;
    let input = match preprocessed {
        Ok(Ok(input)) => input,
        Ok(Err(err)) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({"error":err.to_string()})),
            )
                .into_response();
        }
        Err(err) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error":format!("preprocess worker failed: {err}")})),
            )
                .into_response();
        }
    };
    let score_threshold =
        score_threshold_or_default(params.score_threshold, state.default_score_threshold);
    // `input` is not used afterwards, so it is moved rather than cloned.
    match enqueue_page(&state, Some(input), true, score_threshold).await {
        Ok(result) => Json(serde_json::json!({"pages":1,"results":[result]})).into_response(),
        Err(err) => err.into_response(),
    }
}
/// Reads environment variable `key` and parses it as `T`; returns `None` when
/// the variable is unset or does not parse.
fn env_parse<T: std::str::FromStr>(key: &str) -> Option<T> {
    env::var(key).ok().and_then(|raw| raw.parse().ok())
}

/// Entry point: reads configuration from `DOC_LAYOUT_*` environment
/// variables, optionally runs the self test and exits, otherwise starts the
/// batch workers and serves the HTTP API.
///
/// Unset or unparsable optional variables fall back to their defaults.
/// `DOC_LAYOUT_ENGINE` is required.
#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();

    let engine_path = env::var("DOC_LAYOUT_ENGINE").context("DOC_LAYOUT_ENGINE is required")?;
    let sample_path = env::var("DOC_LAYOUT_SAMPLE_IMAGE").ok();
    let max_batch: usize = env_parse("DOC_LAYOUT_MAX_BATCH").unwrap_or(8);
    let max_delay_us: u64 = env_parse("DOC_LAYOUT_MAX_DELAY_US").unwrap_or(1000);
    let queue_capacity: usize = env_parse("DOC_LAYOUT_QUEUE_CAPACITY").unwrap_or(4096);
    // "NaN" parses as a valid f32 and survives `clamp`, which would disable
    // score filtering, so it is treated like an unset value.
    let default_score_threshold = env_parse::<f32>("DOC_LAYOUT_SCORE_THRESHOLD")
        .filter(|value| !value.is_nan())
        .unwrap_or(0.5)
        .clamp(0.0, 1.0);
    let max_upload_mb: usize = env_parse("DOC_LAYOUT_MAX_UPLOAD_MB").unwrap_or(512);
    let workers: usize = env_parse("DOC_LAYOUT_WORKERS").unwrap_or(1).max(1);
    let port: u16 = env_parse("DOC_LAYOUT_PORT").unwrap_or(18082);

    let sample = load_sample(sample_path.as_deref())?;

    if let Some(iters) = env_parse::<usize>("DOC_LAYOUT_SELF_TEST_ITERS") {
        let batch = env_parse::<usize>("DOC_LAYOUT_SELF_TEST_BATCH").unwrap_or(max_batch);
        return run_self_test(&engine_path, &sample, batch, iters);
    }

    let (tx, rx) = bounded::<WorkItem>(queue_capacity);
    let metrics = Arc::new(Metrics::default());
    for worker_id in 0..workers {
        let worker_rx = rx.clone();
        let worker_metrics = metrics.clone();
        let worker_engine_path = engine_path.clone();
        let worker_sample = sample.clone();
        thread::spawn(move || {
            tracing::info!(worker_id, "starting batch worker");
            batch_worker(
                worker_rx,
                worker_metrics,
                worker_engine_path,
                worker_sample,
                max_batch,
                Duration::from_micros(max_delay_us),
            );
        });
    }
    // Only the workers should hold receivers. If this one stayed alive, the
    // channel would never disconnect after every worker exited, and handlers
    // would queue work forever instead of reporting "batch worker stopped".
    drop(rx);

    let state = AppState {
        tx,
        metrics,
        default_score_threshold,
    };
    let app = Router::new()
        .route("/health", get(health))
        .route("/metrics", get(metrics_handler))
        .route("/v1/infer", post(infer))
        .route("/v1/layout", post(layout))
        .route("/v1/layout_tensor", post(layout_tensor))
        .route("/v1/layout_chw_f32", post(layout_chw_f32))
        .route("/v1/layout_chw_u8", post(layout_chw_u8))
        .route("/v1/layout_chw_u8_batch", post(layout_chw_u8_batch))
        // Saturate rather than overflow on an absurdly large configured limit.
        .layer(DefaultBodyLimit::max(
            max_upload_mb.saturating_mul(1024 * 1024),
        ))
        .with_state(state);

    let addr = SocketAddr::from(([0, 0, 0, 0], port));
    tracing::info!(
        %addr,
        workers,
        max_batch,
        max_delay_us,
        max_upload_mb,
        default_score_threshold,
        "starting doclayout TensorRT batcher"
    );
    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}