Mayo committed on
feat: Rust-based processing
Browse files- koharu/src/app.rs +2 -0
- koharu/src/endpoints.rs +67 -72
- koharu/src/lib.rs +1 -0
- koharu/src/ml.rs +86 -28
- koharu/src/pipeline.rs +261 -0
- koharu/src/server.rs +3 -0
- scripts/dev.ts +1 -1
- ui/components/ActivityBubble.tsx +1 -1
- ui/lib/backend.ts +23 -18
- ui/lib/operations.ts +7 -2
- ui/lib/store.ts +79 -151
koharu/src/app.rs
CHANGED
|
@@ -59,6 +59,7 @@ pub struct AppResources {
|
|
| 59 |
pub llm: Arc<llm::Model>,
|
| 60 |
pub renderer: Arc<Renderer>,
|
| 61 |
pub ml_device: DeviceName,
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
#[derive(Parser)]
|
|
@@ -208,6 +209,7 @@ async fn build_resources(cpu: bool, _register_file_assoc: bool) -> Result<AppRes
|
|
| 208 |
llm,
|
| 209 |
renderer,
|
| 210 |
ml_device,
|
|
|
|
| 211 |
})
|
| 212 |
}
|
| 213 |
|
|
|
|
| 59 |
pub llm: Arc<llm::Model>,
|
| 60 |
pub renderer: Arc<Renderer>,
|
| 61 |
pub ml_device: DeviceName,
|
| 62 |
+
pub pipeline: Arc<RwLock<Option<crate::pipeline::PipelineHandle>>>,
|
| 63 |
}
|
| 64 |
|
| 65 |
#[derive(Parser)]
|
|
|
|
| 209 |
llm,
|
| 210 |
renderer,
|
| 211 |
ml_device,
|
| 212 |
+
pipeline: Arc::new(RwLock::new(None)),
|
| 213 |
})
|
| 214 |
}
|
| 215 |
|
koharu/src/endpoints.rs
CHANGED
|
@@ -27,7 +27,7 @@ use crate::{
|
|
| 27 |
khr::{deserialize_khr, has_khr_magic, serialize_khr},
|
| 28 |
llm,
|
| 29 |
result::Result,
|
| 30 |
-
state::{Document, TextBlock
|
| 31 |
version,
|
| 32 |
};
|
| 33 |
|
|
@@ -194,7 +194,7 @@ pub async fn detect(
|
|
| 194 |
State(state): State<AppResources>,
|
| 195 |
Json(payload): Json<IndexPayload>,
|
| 196 |
) -> Result<Json<()>> {
|
| 197 |
-
let snapshot = {
|
| 198 |
let guard = state.state.read().await;
|
| 199 |
guard
|
| 200 |
.documents
|
|
@@ -203,44 +203,14 @@ pub async fn detect(
|
|
| 203 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?
|
| 204 |
};
|
| 205 |
|
| 206 |
-
|
| 207 |
-
let mut updated = snapshot.clone();
|
| 208 |
-
updated.text_blocks = text_blocks;
|
| 209 |
-
updated.segment = Some(segment);
|
| 210 |
-
|
| 211 |
-
if !updated.text_blocks.is_empty() {
|
| 212 |
-
let images: Vec<image::DynamicImage> = updated
|
| 213 |
-
.text_blocks
|
| 214 |
-
.iter()
|
| 215 |
-
.map(|block| {
|
| 216 |
-
updated.image.crop_imm(
|
| 217 |
-
block.x as u32,
|
| 218 |
-
block.y as u32,
|
| 219 |
-
block.width as u32,
|
| 220 |
-
block.height as u32,
|
| 221 |
-
)
|
| 222 |
-
})
|
| 223 |
-
.collect();
|
| 224 |
-
|
| 225 |
-
let font_predictions = state.ml.detect_fonts(&images, 1).await?;
|
| 226 |
-
for (block, prediction) in updated.text_blocks.iter_mut().zip(font_predictions) {
|
| 227 |
-
let color = prediction.text_color;
|
| 228 |
-
let font_size = (prediction.font_size_px > 0.0).then_some(prediction.font_size_px);
|
| 229 |
-
block.font_prediction = Some(prediction);
|
| 230 |
-
block.style = Some(TextStyle {
|
| 231 |
-
font_size,
|
| 232 |
-
color: [color[0], color[1], color[2], 255],
|
| 233 |
-
..Default::default()
|
| 234 |
-
});
|
| 235 |
-
}
|
| 236 |
-
}
|
| 237 |
|
| 238 |
let mut guard = state.state.write().await;
|
| 239 |
let document = guard
|
| 240 |
.documents
|
| 241 |
.get_mut(payload.index)
|
| 242 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?;
|
| 243 |
-
*document =
|
| 244 |
Ok(Json(()))
|
| 245 |
}
|
| 246 |
|
|
@@ -249,7 +219,7 @@ pub async fn ocr(
|
|
| 249 |
State(state): State<AppResources>,
|
| 250 |
Json(payload): Json<IndexPayload>,
|
| 251 |
) -> Result<Json<()>> {
|
| 252 |
-
let snapshot = {
|
| 253 |
let guard = state.state.read().await;
|
| 254 |
guard
|
| 255 |
.documents
|
|
@@ -258,16 +228,14 @@ pub async fn ocr(
|
|
| 258 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?
|
| 259 |
};
|
| 260 |
|
| 261 |
-
|
| 262 |
-
let mut updated = snapshot;
|
| 263 |
-
updated.text_blocks = text_blocks;
|
| 264 |
|
| 265 |
let mut guard = state.state.write().await;
|
| 266 |
let document = guard
|
| 267 |
.documents
|
| 268 |
.get_mut(payload.index)
|
| 269 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?;
|
| 270 |
-
*document =
|
| 271 |
Ok(Json(()))
|
| 272 |
}
|
| 273 |
|
|
@@ -276,7 +244,7 @@ pub async fn inpaint(
|
|
| 276 |
State(state): State<AppResources>,
|
| 277 |
Json(payload): Json<IndexPayload>,
|
| 278 |
) -> Result<Json<()>> {
|
| 279 |
-
let snapshot = {
|
| 280 |
let guard = state.state.read().await;
|
| 281 |
guard
|
| 282 |
.documents
|
|
@@ -285,43 +253,14 @@ pub async fn inpaint(
|
|
| 285 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?
|
| 286 |
};
|
| 287 |
|
| 288 |
-
|
| 289 |
-
.segment
|
| 290 |
-
.as_ref()
|
| 291 |
-
.ok_or_else(|| anyhow::anyhow!("Segment image not found"))?;
|
| 292 |
-
let text_blocks = &snapshot.text_blocks;
|
| 293 |
-
let mut segment_data = segment.to_rgba8();
|
| 294 |
-
let (seg_width, seg_height) = segment_data.dimensions();
|
| 295 |
-
|
| 296 |
-
for y in 0..seg_height {
|
| 297 |
-
for x in 0..seg_width {
|
| 298 |
-
let pixel = segment_data.get_pixel_mut(x, y);
|
| 299 |
-
if pixel.0 != [0, 0, 0, 255] {
|
| 300 |
-
let inside_any_block = text_blocks.iter().any(|block| {
|
| 301 |
-
x >= block.x as u32
|
| 302 |
-
&& x < (block.x + block.width) as u32
|
| 303 |
-
&& y >= block.y as u32
|
| 304 |
-
&& y < (block.y + block.height) as u32
|
| 305 |
-
});
|
| 306 |
-
if !inside_any_block {
|
| 307 |
-
*pixel = image::Rgba([0, 0, 0, 255]);
|
| 308 |
-
}
|
| 309 |
-
}
|
| 310 |
-
}
|
| 311 |
-
}
|
| 312 |
-
|
| 313 |
-
let mask = SerializableDynamicImage::from(image::DynamicImage::ImageRgba8(segment_data));
|
| 314 |
-
let inpainted = state.ml.inpaint(&snapshot.image, &mask).await?;
|
| 315 |
-
|
| 316 |
-
let mut updated = snapshot;
|
| 317 |
-
updated.inpainted = Some(inpainted);
|
| 318 |
|
| 319 |
let mut guard = state.state.write().await;
|
| 320 |
let document = guard
|
| 321 |
.documents
|
| 322 |
.get_mut(payload.index)
|
| 323 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?;
|
| 324 |
-
*document =
|
| 325 |
Ok(Json(()))
|
| 326 |
}
|
| 327 |
|
|
@@ -544,7 +483,7 @@ pub async fn inpaint_partial(
|
|
| 544 |
SerializableDynamicImage(snapshot.image.crop_imm(x0, y0, crop_width, crop_height));
|
| 545 |
let mask_crop = SerializableDynamicImage(mask_image.crop_imm(x0, y0, crop_width, crop_height));
|
| 546 |
|
| 547 |
-
let inpainted_crop = state.ml.
|
| 548 |
|
| 549 |
let mut stitched = snapshot
|
| 550 |
.inpainted
|
|
@@ -849,3 +788,59 @@ pub async fn download_progress()
|
|
| 849 |
});
|
| 850 |
Sse::new(stream).keep_alive(KeepAlive::default())
|
| 851 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
khr::{deserialize_khr, has_khr_magic, serialize_khr},
|
| 28 |
llm,
|
| 29 |
result::Result,
|
| 30 |
+
state::{Document, TextBlock},
|
| 31 |
version,
|
| 32 |
};
|
| 33 |
|
|
|
|
| 194 |
State(state): State<AppResources>,
|
| 195 |
Json(payload): Json<IndexPayload>,
|
| 196 |
) -> Result<Json<()>> {
|
| 197 |
+
let mut snapshot = {
|
| 198 |
let guard = state.state.read().await;
|
| 199 |
guard
|
| 200 |
.documents
|
|
|
|
| 203 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?
|
| 204 |
};
|
| 205 |
|
| 206 |
+
state.ml.detect(&mut snapshot).await?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
let mut guard = state.state.write().await;
|
| 209 |
let document = guard
|
| 210 |
.documents
|
| 211 |
.get_mut(payload.index)
|
| 212 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?;
|
| 213 |
+
*document = snapshot;
|
| 214 |
Ok(Json(()))
|
| 215 |
}
|
| 216 |
|
|
|
|
| 219 |
State(state): State<AppResources>,
|
| 220 |
Json(payload): Json<IndexPayload>,
|
| 221 |
) -> Result<Json<()>> {
|
| 222 |
+
let mut snapshot = {
|
| 223 |
let guard = state.state.read().await;
|
| 224 |
guard
|
| 225 |
.documents
|
|
|
|
| 228 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?
|
| 229 |
};
|
| 230 |
|
| 231 |
+
state.ml.ocr(&mut snapshot).await?;
|
|
|
|
|
|
|
| 232 |
|
| 233 |
let mut guard = state.state.write().await;
|
| 234 |
let document = guard
|
| 235 |
.documents
|
| 236 |
.get_mut(payload.index)
|
| 237 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?;
|
| 238 |
+
*document = snapshot;
|
| 239 |
Ok(Json(()))
|
| 240 |
}
|
| 241 |
|
|
|
|
| 244 |
State(state): State<AppResources>,
|
| 245 |
Json(payload): Json<IndexPayload>,
|
| 246 |
) -> Result<Json<()>> {
|
| 247 |
+
let mut snapshot = {
|
| 248 |
let guard = state.state.read().await;
|
| 249 |
guard
|
| 250 |
.documents
|
|
|
|
| 253 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?
|
| 254 |
};
|
| 255 |
|
| 256 |
+
state.ml.inpaint(&mut snapshot).await?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
let mut guard = state.state.write().await;
|
| 259 |
let document = guard
|
| 260 |
.documents
|
| 261 |
.get_mut(payload.index)
|
| 262 |
.ok_or_else(|| anyhow::anyhow!("Document not found"))?;
|
| 263 |
+
*document = snapshot;
|
| 264 |
Ok(Json(()))
|
| 265 |
}
|
| 266 |
|
|
|
|
| 483 |
SerializableDynamicImage(snapshot.image.crop_imm(x0, y0, crop_width, crop_height));
|
| 484 |
let mask_crop = SerializableDynamicImage(mask_image.crop_imm(x0, y0, crop_width, crop_height));
|
| 485 |
|
| 486 |
+
let inpainted_crop = state.ml.inpaint_raw(&image_crop, &mask_crop).await?;
|
| 487 |
|
| 488 |
let mut stitched = snapshot
|
| 489 |
.inpainted
|
|
|
|
| 788 |
});
|
| 789 |
Sse::new(stream).keep_alive(KeepAlive::default())
|
| 790 |
}
|
| 791 |
+
|
| 792 |
+
// --- Auto-processing pipeline endpoints ---
|
| 793 |
+
|
| 794 |
+
/// Start the processing pipeline. Returns immediately; progress is available
|
| 795 |
+
/// via the `process_progress` SSE endpoint.
|
| 796 |
+
pub async fn process(
|
| 797 |
+
State(state): State<AppResources>,
|
| 798 |
+
Json(payload): Json<crate::pipeline::ProcessRequest>,
|
| 799 |
+
) -> Result<Json<()>> {
|
| 800 |
+
{
|
| 801 |
+
let guard = state.pipeline.read().await;
|
| 802 |
+
if guard.is_some() {
|
| 803 |
+
Err(anyhow::anyhow!("A processing pipeline is already running"))?;
|
| 804 |
+
}
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
|
| 808 |
+
{
|
| 809 |
+
let mut guard = state.pipeline.write().await;
|
| 810 |
+
*guard = Some(crate::pipeline::PipelineHandle {
|
| 811 |
+
cancel: cancel.clone(),
|
| 812 |
+
});
|
| 813 |
+
}
|
| 814 |
+
|
| 815 |
+
let resources = state.clone();
|
| 816 |
+
tokio::spawn(async move {
|
| 817 |
+
crate::pipeline::run_pipeline(resources, payload, cancel).await;
|
| 818 |
+
});
|
| 819 |
+
|
| 820 |
+
Ok(Json(()))
|
| 821 |
+
}
|
| 822 |
+
|
| 823 |
+
pub async fn process_progress()
|
| 824 |
+
-> Sse<impl futures::Stream<Item = std::result::Result<Event, Infallible>>> {
|
| 825 |
+
let rx = crate::pipeline::subscribe();
|
| 826 |
+
let stream = futures::stream::unfold(rx, |mut rx| async {
|
| 827 |
+
loop {
|
| 828 |
+
match rx.recv().await {
|
| 829 |
+
Ok(p) => return Some((Ok(Event::default().json_data(p).unwrap()), rx)),
|
| 830 |
+
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
|
| 831 |
+
Err(tokio::sync::broadcast::error::RecvError::Closed) => return None,
|
| 832 |
+
}
|
| 833 |
+
}
|
| 834 |
+
});
|
| 835 |
+
Sse::new(stream).keep_alive(KeepAlive::default())
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
pub async fn process_cancel(State(state): State<AppResources>) -> Result<Json<()>> {
|
| 839 |
+
let guard = state.pipeline.read().await;
|
| 840 |
+
if let Some(handle) = guard.as_ref() {
|
| 841 |
+
handle
|
| 842 |
+
.cancel
|
| 843 |
+
.store(true, std::sync::atomic::Ordering::Relaxed);
|
| 844 |
+
}
|
| 845 |
+
Ok(Json(()))
|
| 846 |
+
}
|
koharu/src/lib.rs
CHANGED
|
@@ -5,6 +5,7 @@ pub mod image;
|
|
| 5 |
pub mod khr;
|
| 6 |
pub mod llm;
|
| 7 |
pub mod ml;
|
|
|
|
| 8 |
pub mod renderer;
|
| 9 |
pub mod result;
|
| 10 |
pub mod server;
|
|
|
|
| 5 |
pub mod khr;
|
| 6 |
pub mod llm;
|
| 7 |
pub mod ml;
|
| 8 |
+
pub mod pipeline;
|
| 9 |
pub mod renderer;
|
| 10 |
pub mod result;
|
| 11 |
pub mod server;
|
koharu/src/ml.rs
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
use anyhow::Result;
|
| 2 |
-
use image::DynamicImage;
|
| 3 |
use koharu_ml::comic_text_detector::{self, ComicTextDetector};
|
| 4 |
use koharu_ml::font_detector::{self, FontDetector};
|
| 5 |
use koharu_ml::lama::{self, Lama};
|
| 6 |
use koharu_ml::manga_ocr::{self, MangaOcr};
|
| 7 |
|
| 8 |
use crate::image::SerializableDynamicImage;
|
| 9 |
-
use crate::state::TextBlock;
|
| 10 |
|
| 11 |
const NEAR_BLACK_THRESHOLD: u8 = 12;
|
| 12 |
const GRAY_NEAR_BLACK_THRESHOLD: u8 = 60;
|
|
@@ -85,11 +85,10 @@ impl Model {
|
|
| 85 |
})
|
| 86 |
}
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
let (bboxes, segment) = self.dialog_detector.inference(image)?;
|
| 93 |
|
| 94 |
let mut text_blocks: Vec<TextBlock> = bboxes
|
| 95 |
.into_iter()
|
|
@@ -109,22 +108,51 @@ impl Model {
|
|
| 109 |
.unwrap_or(std::cmp::Ordering::Equal)
|
| 110 |
});
|
| 111 |
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
}
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
if blocks.is_empty() {
|
| 121 |
-
return Ok(Vec::new());
|
| 122 |
}
|
| 123 |
|
| 124 |
-
let crops: Vec<DynamicImage> =
|
|
|
|
| 125 |
.iter()
|
| 126 |
.map(|block| {
|
| 127 |
-
image.crop_imm(
|
| 128 |
block.x as u32,
|
| 129 |
block.y as u32,
|
| 130 |
block.width as u32,
|
|
@@ -134,24 +162,54 @@ impl Model {
|
|
| 134 |
.collect();
|
| 135 |
let texts = self.ocr.inference(&crops)?;
|
| 136 |
|
| 137 |
-
|
| 138 |
-
.
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
}
|
| 147 |
|
| 148 |
-
|
|
|
|
| 149 |
&self,
|
| 150 |
image: &SerializableDynamicImage,
|
| 151 |
mask: &SerializableDynamicImage,
|
| 152 |
) -> Result<SerializableDynamicImage> {
|
| 153 |
let result = self.lama.inference(image, mask)?;
|
| 154 |
-
|
| 155 |
Ok(result.into())
|
| 156 |
}
|
| 157 |
|
|
|
|
| 1 |
use anyhow::Result;
|
| 2 |
+
use image::{DynamicImage, Rgba};
|
| 3 |
use koharu_ml::comic_text_detector::{self, ComicTextDetector};
|
| 4 |
use koharu_ml::font_detector::{self, FontDetector};
|
| 5 |
use koharu_ml::lama::{self, Lama};
|
| 6 |
use koharu_ml::manga_ocr::{self, MangaOcr};
|
| 7 |
|
| 8 |
use crate::image::SerializableDynamicImage;
|
| 9 |
+
use crate::state::{Document, TextBlock, TextStyle};
|
| 10 |
|
| 11 |
const NEAR_BLACK_THRESHOLD: u8 = 12;
|
| 12 |
const GRAY_NEAR_BLACK_THRESHOLD: u8 = 60;
|
|
|
|
| 85 |
})
|
| 86 |
}
|
| 87 |
|
| 88 |
+
/// Detect text blocks and fonts in a document.
|
| 89 |
+
/// Sets `doc.text_blocks` (with font predictions/styles) and `doc.segment`.
|
| 90 |
+
pub async fn detect(&self, doc: &mut Document) -> Result<()> {
|
| 91 |
+
let (bboxes, segment) = self.dialog_detector.inference(&doc.image)?;
|
|
|
|
| 92 |
|
| 93 |
let mut text_blocks: Vec<TextBlock> = bboxes
|
| 94 |
.into_iter()
|
|
|
|
| 108 |
.unwrap_or(std::cmp::Ordering::Equal)
|
| 109 |
});
|
| 110 |
|
| 111 |
+
doc.text_blocks = text_blocks;
|
| 112 |
+
doc.segment = Some(DynamicImage::ImageLuma8(segment).into());
|
| 113 |
+
|
| 114 |
+
if !doc.text_blocks.is_empty() {
|
| 115 |
+
let images: Vec<DynamicImage> = doc
|
| 116 |
+
.text_blocks
|
| 117 |
+
.iter()
|
| 118 |
+
.map(|block| {
|
| 119 |
+
doc.image.crop_imm(
|
| 120 |
+
block.x as u32,
|
| 121 |
+
block.y as u32,
|
| 122 |
+
block.width as u32,
|
| 123 |
+
block.height as u32,
|
| 124 |
+
)
|
| 125 |
+
})
|
| 126 |
+
.collect();
|
| 127 |
+
|
| 128 |
+
let font_predictions = self.detect_fonts(&images, 1).await?;
|
| 129 |
+
for (block, prediction) in doc.text_blocks.iter_mut().zip(font_predictions) {
|
| 130 |
+
let color = prediction.text_color;
|
| 131 |
+
let font_size = (prediction.font_size_px > 0.0).then_some(prediction.font_size_px);
|
| 132 |
+
block.font_prediction = Some(prediction);
|
| 133 |
+
block.style = Some(TextStyle {
|
| 134 |
+
font_size,
|
| 135 |
+
color: [color[0], color[1], color[2], 255],
|
| 136 |
+
..Default::default()
|
| 137 |
+
});
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
Ok(())
|
| 142 |
}
|
| 143 |
|
| 144 |
+
/// Run OCR on all text blocks in the document.
|
| 145 |
+
/// Updates `doc.text_blocks` with recognized text.
|
| 146 |
+
pub async fn ocr(&self, doc: &mut Document) -> Result<()> {
|
| 147 |
+
if doc.text_blocks.is_empty() {
|
| 148 |
+
return Ok(());
|
|
|
|
|
|
|
| 149 |
}
|
| 150 |
|
| 151 |
+
let crops: Vec<DynamicImage> = doc
|
| 152 |
+
.text_blocks
|
| 153 |
.iter()
|
| 154 |
.map(|block| {
|
| 155 |
+
doc.image.crop_imm(
|
| 156 |
block.x as u32,
|
| 157 |
block.y as u32,
|
| 158 |
block.width as u32,
|
|
|
|
| 162 |
.collect();
|
| 163 |
let texts = self.ocr.inference(&crops)?;
|
| 164 |
|
| 165 |
+
for (block, text) in doc.text_blocks.iter_mut().zip(texts) {
|
| 166 |
+
block.text = text.into();
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
Ok(())
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
/// Inpaint text regions in the document.
|
| 173 |
+
/// Builds mask from `doc.segment` + `doc.text_blocks`, sets `doc.inpainted`.
|
| 174 |
+
pub async fn inpaint(&self, doc: &mut Document) -> Result<()> {
|
| 175 |
+
let segment = doc
|
| 176 |
+
.segment
|
| 177 |
+
.as_ref()
|
| 178 |
+
.ok_or_else(|| anyhow::anyhow!("Segment image not found"))?;
|
| 179 |
+
let mut segment_data = segment.to_rgba8();
|
| 180 |
+
let (seg_width, seg_height) = segment_data.dimensions();
|
| 181 |
+
|
| 182 |
+
for y in 0..seg_height {
|
| 183 |
+
for x in 0..seg_width {
|
| 184 |
+
let pixel = segment_data.get_pixel_mut(x, y);
|
| 185 |
+
if pixel.0 != [0, 0, 0, 255] {
|
| 186 |
+
let inside_any_block = doc.text_blocks.iter().any(|block| {
|
| 187 |
+
x >= block.x as u32
|
| 188 |
+
&& x < (block.x + block.width) as u32
|
| 189 |
+
&& y >= block.y as u32
|
| 190 |
+
&& y < (block.y + block.height) as u32
|
| 191 |
+
});
|
| 192 |
+
if !inside_any_block {
|
| 193 |
+
*pixel = Rgba([0, 0, 0, 255]);
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
let mask = SerializableDynamicImage::from(DynamicImage::ImageRgba8(segment_data));
|
| 200 |
+
let result = self.lama.inference(&doc.image, &mask)?;
|
| 201 |
+
doc.inpainted = Some(result.into());
|
| 202 |
+
|
| 203 |
+
Ok(())
|
| 204 |
}
|
| 205 |
|
| 206 |
+
/// Low-level inpaint: inpaint a specific image region with a mask.
|
| 207 |
+
pub async fn inpaint_raw(
|
| 208 |
&self,
|
| 209 |
image: &SerializableDynamicImage,
|
| 210 |
mask: &SerializableDynamicImage,
|
| 211 |
) -> Result<SerializableDynamicImage> {
|
| 212 |
let result = self.lama.inference(image, mask)?;
|
|
|
|
| 213 |
Ok(result.into())
|
| 214 |
}
|
| 215 |
|
koharu/src/pipeline.rs
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use std::sync::{
|
| 2 |
+
Arc,
|
| 3 |
+
atomic::{AtomicBool, Ordering},
|
| 4 |
+
};
|
| 5 |
+
use std::time::Duration;
|
| 6 |
+
|
| 7 |
+
use koharu_ml::llm::ModelId;
|
| 8 |
+
use koharu_renderer::renderer::TextShaderEffect;
|
| 9 |
+
use once_cell::sync::Lazy;
|
| 10 |
+
use serde::{Deserialize, Serialize};
|
| 11 |
+
use std::str::FromStr;
|
| 12 |
+
use tokio::sync::broadcast;
|
| 13 |
+
|
| 14 |
+
use crate::app::AppResources;
|
| 15 |
+
|
| 16 |
+
/// Steps in the processing pipeline.
|
| 17 |
+
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
|
| 18 |
+
#[serde(rename_all = "camelCase")]
|
| 19 |
+
pub enum PipelineStep {
|
| 20 |
+
Detect,
|
| 21 |
+
Ocr,
|
| 22 |
+
Inpaint,
|
| 23 |
+
LlmGenerate,
|
| 24 |
+
Render,
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
impl PipelineStep {
|
| 28 |
+
pub const ALL: &[PipelineStep] = &[
|
| 29 |
+
PipelineStep::Detect,
|
| 30 |
+
PipelineStep::Ocr,
|
| 31 |
+
PipelineStep::Inpaint,
|
| 32 |
+
PipelineStep::LlmGenerate,
|
| 33 |
+
PipelineStep::Render,
|
| 34 |
+
];
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
/// Status of the pipeline.
|
| 38 |
+
#[derive(Debug, Clone, Serialize, PartialEq)]
|
| 39 |
+
#[serde(rename_all = "camelCase")]
|
| 40 |
+
pub enum PipelineStatus {
|
| 41 |
+
Running,
|
| 42 |
+
Completed,
|
| 43 |
+
Cancelled,
|
| 44 |
+
Failed(String),
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
/// SSE event payload sent to the frontend.
|
| 48 |
+
#[derive(Debug, Clone, Serialize)]
|
| 49 |
+
#[serde(rename_all = "camelCase")]
|
| 50 |
+
pub struct PipelineProgress {
|
| 51 |
+
pub status: PipelineStatus,
|
| 52 |
+
pub step: Option<PipelineStep>,
|
| 53 |
+
pub current_document: usize,
|
| 54 |
+
pub total_documents: usize,
|
| 55 |
+
pub current_step_index: usize,
|
| 56 |
+
pub total_steps: usize,
|
| 57 |
+
pub overall_percent: u8,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
/// Request payload for the auto-process endpoint.
|
| 61 |
+
#[derive(Debug, Deserialize)]
|
| 62 |
+
#[serde(rename_all = "camelCase")]
|
| 63 |
+
pub struct ProcessRequest {
|
| 64 |
+
/// None means all documents; Some(i) means single document.
|
| 65 |
+
pub index: Option<usize>,
|
| 66 |
+
/// LLM model id to use (will load if not already loaded).
|
| 67 |
+
pub llm_model_id: Option<String>,
|
| 68 |
+
/// Target language for translation.
|
| 69 |
+
pub language: Option<String>,
|
| 70 |
+
/// Shader effect for rendering.
|
| 71 |
+
pub shader_effect: Option<TextShaderEffect>,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
/// Handle to a running pipeline, used for cancellation.
///
/// Cloning the handle shares the same underlying flag.
#[derive(Debug, Clone)]
pub struct PipelineHandle {
    /// Set to `true` to ask the pipeline to stop at its next check point.
    pub cancel: Arc<AtomicBool>,
}
|
| 78 |
+
|
| 79 |
+
// Global broadcast channel for pipeline progress (mirrors download.rs pattern).
|
| 80 |
+
static PIPELINE_TX: Lazy<broadcast::Sender<PipelineProgress>> =
|
| 81 |
+
Lazy::new(|| broadcast::channel(256).0);
|
| 82 |
+
|
| 83 |
+
pub fn subscribe() -> broadcast::Receiver<PipelineProgress> {
|
| 84 |
+
PIPELINE_TX.subscribe()
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
fn emit(progress: PipelineProgress) {
|
| 88 |
+
let _ = PIPELINE_TX.send(progress);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
/// Compute overall completion as a whole-number percentage in `0..=100`.
///
/// Progress is measured in "units": one unit per (document, step) pair.
/// `doc` and `step` are the zero-based positions of the unit about to run, so
/// they also count the units already finished.
///
/// Returns 0 when there is no work (`total_docs` or `total_steps` is zero).
/// The result is clamped to 100 even if `doc`/`step` overshoot the totals.
fn compute_percent(doc: usize, step: usize, total_docs: usize, total_steps: usize) -> u8 {
    let total_units = total_docs.saturating_mul(total_steps);
    if total_units == 0 {
        return 0;
    }
    let done_units = doc.saturating_mul(total_steps).saturating_add(step);
    let percent = (done_units as f64 / total_units as f64) * 100.0;
    // Clamp before the cast: a float-to-u8 cast saturates at 255, not 100.
    percent.round().min(100.0) as u8
}
|
| 99 |
+
|
| 100 |
+
/// Run the processing pipeline. Called from a spawned task.
|
| 101 |
+
pub async fn run_pipeline(
|
| 102 |
+
resources: AppResources,
|
| 103 |
+
request: ProcessRequest,
|
| 104 |
+
cancel: Arc<AtomicBool>,
|
| 105 |
+
) {
|
| 106 |
+
let result = run_pipeline_inner(&resources, &request, &cancel).await;
|
| 107 |
+
|
| 108 |
+
let total_docs = match &request.index {
|
| 109 |
+
Some(_) => 1,
|
| 110 |
+
None => resources.state.read().await.documents.len(),
|
| 111 |
+
};
|
| 112 |
+
|
| 113 |
+
match result {
|
| 114 |
+
Ok(()) if cancel.load(Ordering::Relaxed) => {
|
| 115 |
+
emit(PipelineProgress {
|
| 116 |
+
status: PipelineStatus::Cancelled,
|
| 117 |
+
step: None,
|
| 118 |
+
current_document: total_docs,
|
| 119 |
+
total_documents: total_docs,
|
| 120 |
+
current_step_index: 0,
|
| 121 |
+
total_steps: PipelineStep::ALL.len(),
|
| 122 |
+
overall_percent: 0,
|
| 123 |
+
});
|
| 124 |
+
}
|
| 125 |
+
Ok(()) => {
|
| 126 |
+
emit(PipelineProgress {
|
| 127 |
+
status: PipelineStatus::Completed,
|
| 128 |
+
step: None,
|
| 129 |
+
current_document: total_docs,
|
| 130 |
+
total_documents: total_docs,
|
| 131 |
+
current_step_index: PipelineStep::ALL.len(),
|
| 132 |
+
total_steps: PipelineStep::ALL.len(),
|
| 133 |
+
overall_percent: 100,
|
| 134 |
+
});
|
| 135 |
+
}
|
| 136 |
+
Err(e) => {
|
| 137 |
+
tracing::error!("Pipeline failed: {e:#}");
|
| 138 |
+
emit(PipelineProgress {
|
| 139 |
+
status: PipelineStatus::Failed(e.to_string()),
|
| 140 |
+
step: None,
|
| 141 |
+
current_document: 0,
|
| 142 |
+
total_documents: total_docs,
|
| 143 |
+
current_step_index: 0,
|
| 144 |
+
total_steps: PipelineStep::ALL.len(),
|
| 145 |
+
overall_percent: 0,
|
| 146 |
+
});
|
| 147 |
+
}
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// Clear the pipeline handle
|
| 151 |
+
let mut guard = resources.pipeline.write().await;
|
| 152 |
+
*guard = None;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
async fn run_pipeline_inner(
|
| 156 |
+
res: &AppResources,
|
| 157 |
+
req: &ProcessRequest,
|
| 158 |
+
cancel: &Arc<AtomicBool>,
|
| 159 |
+
) -> anyhow::Result<()> {
|
| 160 |
+
let total_docs = {
|
| 161 |
+
let guard = res.state.read().await;
|
| 162 |
+
let len = guard.documents.len();
|
| 163 |
+
match req.index {
|
| 164 |
+
Some(i) if i >= len => {
|
| 165 |
+
anyhow::bail!("Document index {} out of range (have {})", i, len);
|
| 166 |
+
}
|
| 167 |
+
Some(_) => 1,
|
| 168 |
+
None => len,
|
| 169 |
+
}
|
| 170 |
+
};
|
| 171 |
+
|
| 172 |
+
if total_docs == 0 {
|
| 173 |
+
return Ok(());
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
// Ensure LLM is loaded
|
| 177 |
+
if let Some(model_id) = &req.llm_model_id {
|
| 178 |
+
if !res.llm.ready().await {
|
| 179 |
+
let id = ModelId::from_str(model_id)?;
|
| 180 |
+
res.llm.load(id).await;
|
| 181 |
+
// Poll until ready (with timeout)
|
| 182 |
+
for _ in 0..300 {
|
| 183 |
+
if res.llm.ready().await {
|
| 184 |
+
break;
|
| 185 |
+
}
|
| 186 |
+
tokio::time::sleep(Duration::from_millis(100)).await;
|
| 187 |
+
if cancel.load(Ordering::Relaxed) {
|
| 188 |
+
return Ok(());
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
if !res.llm.ready().await {
|
| 192 |
+
anyhow::bail!("LLM failed to load within timeout");
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
if let Some(locale) = req.language.as_ref() {
|
| 198 |
+
koharu_ml::set_locale(locale.clone());
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
let start_index = req.index.unwrap_or(0);
|
| 202 |
+
let end_index = req.index.map(|i| i + 1).unwrap_or(total_docs);
|
| 203 |
+
let total_steps = PipelineStep::ALL.len();
|
| 204 |
+
|
| 205 |
+
for (doc_ordinal, doc_index) in (start_index..end_index).enumerate() {
|
| 206 |
+
for (step_ordinal, step) in PipelineStep::ALL.iter().enumerate() {
|
| 207 |
+
if cancel.load(Ordering::Relaxed) {
|
| 208 |
+
return Ok(());
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
let overall = compute_percent(doc_ordinal, step_ordinal, total_docs, total_steps);
|
| 212 |
+
emit(PipelineProgress {
|
| 213 |
+
status: PipelineStatus::Running,
|
| 214 |
+
step: Some(*step),
|
| 215 |
+
current_document: doc_ordinal,
|
| 216 |
+
total_documents: total_docs,
|
| 217 |
+
current_step_index: step_ordinal,
|
| 218 |
+
total_steps,
|
| 219 |
+
overall_percent: overall,
|
| 220 |
+
});
|
| 221 |
+
|
| 222 |
+
// Give the runtime a chance to flush the SSE event before blocking ML work
|
| 223 |
+
tokio::task::yield_now().await;
|
| 224 |
+
tokio::time::sleep(Duration::from_millis(1)).await;
|
| 225 |
+
|
| 226 |
+
// Snapshot → process → write back
|
| 227 |
+
let mut snapshot =
|
| 228 |
+
{
|
| 229 |
+
let guard = res.state.read().await;
|
| 230 |
+
guard.documents.get(doc_index).cloned().ok_or_else(|| {
|
| 231 |
+
anyhow::anyhow!("Document not found at index {}", doc_index)
|
| 232 |
+
})?
|
| 233 |
+
};
|
| 234 |
+
|
| 235 |
+
match step {
|
| 236 |
+
PipelineStep::Detect => res.ml.detect(&mut snapshot).await?,
|
| 237 |
+
PipelineStep::Ocr => res.ml.ocr(&mut snapshot).await?,
|
| 238 |
+
PipelineStep::Inpaint => res.ml.inpaint(&mut snapshot).await?,
|
| 239 |
+
PipelineStep::LlmGenerate => {
|
| 240 |
+
res.llm.generate(&mut snapshot).await?;
|
| 241 |
+
}
|
| 242 |
+
PipelineStep::Render => {
|
| 243 |
+
res.renderer.render(
|
| 244 |
+
&mut snapshot,
|
| 245 |
+
None,
|
| 246 |
+
req.shader_effect.unwrap_or_default(),
|
| 247 |
+
)?;
|
| 248 |
+
}
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
let mut guard = res.state.write().await;
|
| 252 |
+
let document = guard
|
| 253 |
+
.documents
|
| 254 |
+
.get_mut(doc_index)
|
| 255 |
+
.ok_or_else(|| anyhow::anyhow!("Document not found at index {}", doc_index))?;
|
| 256 |
+
*document = snapshot;
|
| 257 |
+
}
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
Ok(())
|
| 261 |
+
}
|
koharu/src/server.rs
CHANGED
|
@@ -76,6 +76,9 @@ fn build_router(state: AppResources) -> Router {
|
|
| 76 |
.route("/api/llm_ready", get(llm_ready).post(llm_ready))
|
| 77 |
.route("/api/llm_generate", post(llm_generate))
|
| 78 |
.route("/api/download_progress", get(download_progress))
|
|
|
|
|
|
|
|
|
|
| 79 |
.with_state(state)
|
| 80 |
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
| 81 |
.layer(
|
|
|
|
| 76 |
.route("/api/llm_ready", get(llm_ready).post(llm_ready))
|
| 77 |
.route("/api/llm_generate", post(llm_generate))
|
| 78 |
.route("/api/download_progress", get(download_progress))
|
| 79 |
+
.route("/api/process", post(process))
|
| 80 |
+
.route("/api/process_cancel", post(process_cancel))
|
| 81 |
+
.route("/api/process_progress", get(process_progress))
|
| 82 |
.with_state(state)
|
| 83 |
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
| 84 |
.layer(
|
scripts/dev.ts
CHANGED
|
@@ -58,7 +58,7 @@ async function setupCuda() {
|
|
| 58 |
}
|
| 59 |
|
| 60 |
throw new Error(
|
| 61 |
-
'NVCC not found. Please install the CUDA Toolkit from https://developer.nvidia.com/cuda-
|
| 62 |
)
|
| 63 |
}
|
| 64 |
|
|
|
|
| 58 |
}
|
| 59 |
|
| 60 |
throw new Error(
|
| 61 |
+
'NVCC not found. Please install the CUDA Toolkit from https://developer.nvidia.com/cuda-downloads',
|
| 62 |
)
|
| 63 |
}
|
| 64 |
|
ui/components/ActivityBubble.tsx
CHANGED
|
@@ -28,7 +28,7 @@ function ProgressBar({ percent }: { percent?: number }) {
|
|
| 28 |
<div className='bg-muted relative h-1.5 flex-1 overflow-hidden rounded-full'>
|
| 29 |
{typeof percent === 'number' ? (
|
| 30 |
<div
|
| 31 |
-
className='bg-primary h-full rounded-full transition-[width] duration-
|
| 32 |
style={{ width: `${percent}%` }}
|
| 33 |
/>
|
| 34 |
) : (
|
|
|
|
| 28 |
<div className='bg-muted relative h-1.5 flex-1 overflow-hidden rounded-full'>
|
| 29 |
{typeof percent === 'number' ? (
|
| 30 |
<div
|
| 31 |
+
className='bg-primary h-full rounded-full transition-[width] duration-700 ease-out'
|
| 32 |
style={{ width: `${percent}%` }}
|
| 33 |
/>
|
| 34 |
) : (
|
ui/lib/backend.ts
CHANGED
|
@@ -231,38 +231,43 @@ export type DownloadProgress = {
|
|
| 231 |
status: 'Started' | 'Downloading' | 'Completed' | { Failed: string }
|
| 232 |
}
|
| 233 |
|
| 234 |
-
export
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
): () => void {
|
| 237 |
-
let
|
| 238 |
|
| 239 |
;(async () => {
|
| 240 |
await ensureInitialized()
|
| 241 |
-
|
| 242 |
es.onmessage = (event) => {
|
| 243 |
-
if (stopped) {
|
| 244 |
-
es.close()
|
| 245 |
-
return
|
| 246 |
-
}
|
| 247 |
try {
|
| 248 |
-
callback(JSON.parse(event.data) as
|
| 249 |
} catch (_) {}
|
| 250 |
}
|
| 251 |
-
const check = () => {
|
| 252 |
-
if (stopped) es.close()
|
| 253 |
-
}
|
| 254 |
-
const id = setInterval(check, 1000)
|
| 255 |
-
es.onerror = () => {
|
| 256 |
-
clearInterval(id)
|
| 257 |
-
es.close()
|
| 258 |
-
}
|
| 259 |
})()
|
| 260 |
|
| 261 |
return () => {
|
| 262 |
-
|
| 263 |
}
|
| 264 |
}
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
export const isTauri = isTauriEnv
|
| 267 |
|
| 268 |
export const isMacOS = (): boolean => {
|
|
|
|
| 231 |
status: 'Started' | 'Downloading' | 'Completed' | { Failed: string }
|
| 232 |
}
|
| 233 |
|
| 234 |
+
/**
 * State reported by a pipeline progress event. The object form carries the
 * failure message.
 */
type ProcessProgressStatus =
  | 'running'
  | 'completed'
  | 'cancelled'
  | { failed: string }

/**
 * Shape of one pipeline progress event as parsed from JSON.
 * NOTE(review): field meanings are inferred from their names and from how the
 * UI consumes them — confirm against the backend serializer.
 */
export type ProcessProgress = {
  status: ProcessProgressStatus
  // Name of the step being executed; the type allows null (presumably when no
  // step is active — TODO confirm).
  step: string | null
  currentDocument: number
  totalDocuments: number
  currentStepIndex: number
  totalSteps: number
  // presumably a 0–100 percentage — verify against the backend
  overallPercent: number
}
|
| 243 |
+
|
| 244 |
+
/**
 * Subscribes to a server-sent-events endpoint and forwards each message,
 * parsed as JSON, to `callback`.
 *
 * The stream is opened asynchronously, only after `ensureInitialized()` has
 * resolved, at `${apiBase}/${endpoint}`.
 *
 * @param endpoint path segment appended to `apiBase`
 * @param callback receives every successfully parsed message payload
 * @returns a disposer that closes the stream; it is safe to call before the
 *          stream has been opened and safe to call more than once
 */
function subscribeSSE<T>(
  endpoint: string,
  callback: (data: T) => void,
): () => void {
  let es: EventSource | null = null
  // Set by the disposer. Without it, a disposer invoked while initialization
  // is still pending would find `es === null`, do nothing, and the stream
  // opened afterwards could never be closed.
  let disposed = false

  ;(async () => {
    await ensureInitialized()
    if (disposed) return
    es = new EventSource(`${apiBase}/${endpoint}`)
    es.onmessage = (event) => {
      try {
        callback(JSON.parse(event.data) as T)
      } catch (_) {
        // Best effort: a malformed payload or a throwing callback must not
        // break the subscription for later messages.
      }
    }
  })()

  return () => {
    disposed = true
    es?.close()
    es = null
  }
}
|
| 264 |
|
| 265 |
+
/**
 * Listens on the `download_progress` SSE endpoint.
 *
 * @param cb invoked with each `DownloadProgress` payload
 * @returns the disposer produced by `subscribeSSE`
 */
export function subscribeDownloadProgress(
  cb: (p: DownloadProgress) => void,
): () => void {
  return subscribeSSE<DownloadProgress>('download_progress', cb)
}
|
| 267 |
+
|
| 268 |
+
/**
 * Listens on the `process_progress` SSE endpoint.
 *
 * @param cb invoked with each `ProcessProgress` payload
 * @returns the disposer produced by `subscribeSSE`
 */
export function subscribeProcessProgress(
  cb: (p: ProcessProgress) => void,
): () => void {
  return subscribeSSE<ProcessProgress>('process_progress', cb)
}
|
| 270 |
+
|
| 271 |
export const isTauri = isTauriEnv
|
| 272 |
|
| 273 |
export const isMacOS = (): boolean => {
|
ui/lib/operations.ts
CHANGED
|
@@ -41,10 +41,15 @@ export const createOperationSlice = (set: any): OperationSlice => ({
|
|
| 41 |
: { operation: undefined },
|
| 42 |
),
|
| 43 |
finishOperation: () => set({ operation: undefined }),
|
| 44 |
-
cancelOperation: () =>
|
| 45 |
set((state: OperationSlice) =>
|
| 46 |
state.operation
|
| 47 |
? { operation: { ...state.operation, cancelRequested: true } }
|
| 48 |
: { operation: undefined },
|
| 49 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
})
|
|
|
|
| 41 |
: { operation: undefined },
|
| 42 |
),
|
| 43 |
finishOperation: () => set({ operation: undefined }),
|
| 44 |
+
cancelOperation: () => {
  // Mark the active operation (if any) as cancel-requested in UI state.
  set((state: OperationSlice) =>
    state.operation
      ? { operation: { ...state.operation, cancelRequested: true } }
      : { operation: undefined },
  )
  // Also ask the backend to cancel its pipeline. The request is sent
  // unconditionally, whether or not the UI is tracking an operation.
  // Best effort: one trailing catch covers both a failed module load and a
  // failed request, so neither surfaces as an unhandled rejection.
  // NOTE(review): the backend module is loaded lazily here, presumably to
  // avoid a static import cycle — confirm.
  import('@/lib/backend')
    .then(({ invoke }) => invoke('process_cancel'))
    .catch(() => {})
},
|
| 55 |
})
|
ui/lib/store.ts
CHANGED
|
@@ -1,7 +1,13 @@
|
|
| 1 |
'use client'
|
| 2 |
|
| 3 |
import { create } from 'zustand'
|
| 4 |
-
import {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import {
|
| 6 |
Document,
|
| 7 |
InpaintRegion,
|
|
@@ -15,18 +21,6 @@ type LlmModelInfo = {
|
|
| 15 |
languages: string[]
|
| 16 |
}
|
| 17 |
|
| 18 |
-
type ProcessAction = 'detect' | 'ocr' | 'inpaint' | 'llmGenerate' | 'render'
|
| 19 |
-
|
| 20 |
-
type ProcessImageOptionsObject = {
|
| 21 |
-
onProgress?: (progress: number) => Promise<void>
|
| 22 |
-
onStepChange?: (step: ProcessAction) => Promise<void> | void
|
| 23 |
-
skipOperationTracking?: boolean
|
| 24 |
-
}
|
| 25 |
-
|
| 26 |
-
type ProcessImageOptions =
|
| 27 |
-
| ((progress: number) => Promise<void>)
|
| 28 |
-
| ProcessImageOptionsObject
|
| 29 |
-
|
| 30 |
const createTextBlockSyncer = () => {
|
| 31 |
let pending: {
|
| 32 |
index: number
|
|
@@ -211,11 +205,7 @@ type AppState = OperationSlice & {
|
|
| 211 |
index?: number,
|
| 212 |
textBlockIndex?: number,
|
| 213 |
) => Promise<void>
|
| 214 |
-
processImage: (
|
| 215 |
-
_?: any,
|
| 216 |
-
index?: number,
|
| 217 |
-
options?: ProcessImageOptions,
|
| 218 |
-
) => Promise<void>
|
| 219 |
inpaintAndRenderImage: (_?: any, index?: number) => Promise<void>
|
| 220 |
processAllImages: () => Promise<void>
|
| 221 |
exportDocument: () => Promise<void>
|
|
@@ -619,98 +609,27 @@ export const useAppStore = create<AppState>((set, get) => {
|
|
| 619 |
void get().renderTextBlock(undefined, index, textBlockIndex)
|
| 620 |
}
|
| 621 |
},
|
| 622 |
-
//
|
| 623 |
-
processImage: async (_, index
|
| 624 |
-
const normalizedOptions: ProcessImageOptionsObject =
|
| 625 |
-
typeof options === 'function'
|
| 626 |
-
? { onProgress: options, skipOperationTracking: false }
|
| 627 |
-
: (options ?? {})
|
| 628 |
-
|
| 629 |
-
const { onProgress, onStepChange, skipOperationTracking } =
|
| 630 |
-
normalizedOptions
|
| 631 |
-
const operation = get().operation
|
| 632 |
-
const isBatchRun = operation?.type === 'process-all'
|
| 633 |
-
|
| 634 |
-
if (!get().llmReady) {
|
| 635 |
-
await get().llmList()
|
| 636 |
-
await get().llmToggleLoadUnload()
|
| 637 |
-
}
|
| 638 |
-
|
| 639 |
index = index ?? get().currentDocumentIndex
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
'
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
const firstStep = actions[0] ?? 'detect'
|
| 656 |
-
if (ownsOperation) {
|
| 657 |
-
get().startOperation({
|
| 658 |
-
type: 'process-current',
|
| 659 |
-
step: firstStep,
|
| 660 |
-
current: 0,
|
| 661 |
-
total: totalSteps,
|
| 662 |
-
cancellable: true,
|
| 663 |
-
})
|
| 664 |
-
} else {
|
| 665 |
-
get().updateOperation({
|
| 666 |
-
step: firstStep,
|
| 667 |
-
current: 0,
|
| 668 |
-
total: totalSteps,
|
| 669 |
-
})
|
| 670 |
-
}
|
| 671 |
-
}
|
| 672 |
-
|
| 673 |
-
await setProgres(0)
|
| 674 |
-
for (let i = 0; i < actions.length; i++) {
|
| 675 |
-
if (get().operation?.cancelRequested) {
|
| 676 |
-
break
|
| 677 |
-
}
|
| 678 |
-
|
| 679 |
-
const action = actions[i]
|
| 680 |
-
|
| 681 |
-
if (onStepChange) {
|
| 682 |
-
await onStepChange(action)
|
| 683 |
-
}
|
| 684 |
-
|
| 685 |
-
if (shouldTrackOperation) {
|
| 686 |
-
get().updateOperation({
|
| 687 |
-
step: action,
|
| 688 |
-
current: i,
|
| 689 |
-
total: totalSteps,
|
| 690 |
-
})
|
| 691 |
-
}
|
| 692 |
-
|
| 693 |
-
await (get() as any)[actions[i]](_, index)
|
| 694 |
-
await setProgres(Math.floor(((i + 1) / totalSteps) * 100))
|
| 695 |
-
}
|
| 696 |
-
|
| 697 |
-
const cancelled = get().operation?.cancelRequested
|
| 698 |
-
|
| 699 |
-
if (shouldTrackOperation && ownsOperation && !cancelled) {
|
| 700 |
-
get().updateOperation({ current: totalSteps, total: totalSteps })
|
| 701 |
-
}
|
| 702 |
-
|
| 703 |
-
if (shouldTrackOperation && ownsOperation) {
|
| 704 |
get().finishOperation()
|
| 705 |
-
}
|
| 706 |
-
|
| 707 |
-
if (!onProgress) {
|
| 708 |
await get().clearProgress()
|
| 709 |
}
|
| 710 |
-
|
| 711 |
-
if (cancelled) {
|
| 712 |
-
return
|
| 713 |
-
}
|
| 714 |
},
|
| 715 |
|
| 716 |
inpaintAndRenderImage: async (_, index) => {
|
|
@@ -723,59 +642,23 @@ export const useAppStore = create<AppState>((set, get) => {
|
|
| 723 |
const total = get().totalPages
|
| 724 |
if (!total) return
|
| 725 |
|
| 726 |
-
if (!get().llmReady) {
|
| 727 |
-
await get().llmList()
|
| 728 |
-
await get().llmToggleLoadUnload()
|
| 729 |
-
}
|
| 730 |
-
|
| 731 |
get().startOperation({
|
| 732 |
type: 'process-all',
|
| 733 |
cancellable: true,
|
| 734 |
current: 0,
|
| 735 |
total,
|
| 736 |
})
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
await get().setCurrentDocumentIndex(index)
|
| 743 |
-
get().updateOperation({
|
| 744 |
-
current: index,
|
| 745 |
-
total,
|
| 746 |
-
})
|
| 747 |
-
|
| 748 |
-
await get().processImage(null, index, {
|
| 749 |
-
onProgress: async (progress) => {
|
| 750 |
-
if (get().operation?.cancelRequested) return
|
| 751 |
-
const currentValue = index + progress / 100
|
| 752 |
-
const overall = Math.min(
|
| 753 |
-
100,
|
| 754 |
-
Math.round((currentValue / total) * 100),
|
| 755 |
-
)
|
| 756 |
-
await get().setProgress(overall)
|
| 757 |
-
get().updateOperation({ current: currentValue, total })
|
| 758 |
-
},
|
| 759 |
-
onStepChange: (step) => {
|
| 760 |
-
if (get().operation?.cancelRequested) return
|
| 761 |
-
get().updateOperation({ step })
|
| 762 |
-
},
|
| 763 |
-
skipOperationTracking: true,
|
| 764 |
})
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
get().updateOperation({ current: index + 1, total })
|
| 771 |
-
}
|
| 772 |
-
|
| 773 |
-
if (!get().operation?.cancelRequested) {
|
| 774 |
-
get().updateOperation({ current: total, total })
|
| 775 |
}
|
| 776 |
-
|
| 777 |
-
await get().clearProgress()
|
| 778 |
-
get().finishOperation()
|
| 779 |
},
|
| 780 |
|
| 781 |
exportDocument: async () => {
|
|
@@ -807,6 +690,51 @@ type ConfigState = {
|
|
| 807 |
setBrushConfig: (config: Partial<ConfigState['brushConfig']>) => void
|
| 808 |
}
|
| 809 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 810 |
export const useConfigStore = create<ConfigState>((set) => ({
|
| 811 |
brushConfig: {
|
| 812 |
size: 36,
|
|
|
|
| 1 |
'use client'
|
| 2 |
|
| 3 |
import { create } from 'zustand'
|
| 4 |
+
import {
|
| 5 |
+
invoke,
|
| 6 |
+
subscribeProcessProgress,
|
| 7 |
+
getCurrentWindow,
|
| 8 |
+
ProgressBarStatus,
|
| 9 |
+
type ProcessProgress,
|
| 10 |
+
} from '@/lib/backend'
|
| 11 |
import {
|
| 12 |
Document,
|
| 13 |
InpaintRegion,
|
|
|
|
| 21 |
languages: string[]
|
| 22 |
}
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
const createTextBlockSyncer = () => {
|
| 25 |
let pending: {
|
| 26 |
index: number
|
|
|
|
| 205 |
index?: number,
|
| 206 |
textBlockIndex?: number,
|
| 207 |
) => Promise<void>
|
| 208 |
+
processImage: (_?: any, index?: number) => Promise<void>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
inpaintAndRenderImage: (_?: any, index?: number) => Promise<void>
|
| 210 |
processAllImages: () => Promise<void>
|
| 211 |
exportDocument: () => Promise<void>
|
|
|
|
| 609 |
void get().renderTextBlock(undefined, index, textBlockIndex)
|
| 610 |
}
|
| 611 |
},
|
| 612 |
+
// Auto-processing: delegates to the backend pipeline; progress via SSE
|
| 613 |
+
processImage: async (_, index) => {
  // Fall back to the document currently shown when no index is supplied.
  index = index ?? get().currentDocumentIndex

  // Show the operation right away; later progress updates and the final
  // cleanup are not performed here once the request has been accepted.
  get().startOperation({
    type: 'process-current',
    cancellable: true,
    current: 0,
    total: 5,
  })

  try {
    const request = {
      index,
      llmModelId: get().llmSelectedModel,
      language: get().llmSelectedLanguage,
      shaderEffect: get().renderEffect,
    }
    await invoke('process', request)
  } catch (err) {
    // The pipeline never started, so undo the operation state set above.
    console.error('Failed to start processing:', err)
    get().finishOperation()
    await get().clearProgress()
  }
},
|
| 634 |
|
| 635 |
inpaintAndRenderImage: async (_, index) => {
|
|
|
|
| 642 |
const total = get().totalPages
|
| 643 |
if (!total) return
|
| 644 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
get().startOperation({
|
| 646 |
type: 'process-all',
|
| 647 |
cancellable: true,
|
| 648 |
current: 0,
|
| 649 |
total,
|
| 650 |
})
|
| 651 |
+
try {
|
| 652 |
+
await invoke('process', {
|
| 653 |
+
llmModelId: get().llmSelectedModel,
|
| 654 |
+
language: get().llmSelectedLanguage,
|
| 655 |
+
shaderEffect: get().renderEffect,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
})
|
| 657 |
+
} catch (err) {
|
| 658 |
+
console.error('Failed to start processing:', err)
|
| 659 |
+
get().finishOperation()
|
| 660 |
+
await get().clearProgress()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
}
|
|
|
|
|
|
|
|
|
|
| 662 |
},
|
| 663 |
|
| 664 |
exportDocument: async () => {
|
|
|
|
| 690 |
setBrushConfig: (config: Partial<ConfigState['brushConfig']>) => void
|
| 691 |
}
|
| 692 |
|
| 693 |
+
// Subscribe to pipeline progress at module level so the EventSource is
// connected before any pipeline is started (avoids race with lazy component mount).
// The subscription is skipped when `window` is undefined.
if (typeof window !== 'undefined') {
  subscribeProcessProgress((progress: ProcessProgress) => {
    const s = useAppStore.getState()

    if (progress.status === 'running') {
      // A single-document run reports progress in steps; a multi-document run
      // reports it in documents, with the step position as a fraction.
      const isSingleDoc = progress.totalDocuments <= 1
      const stepFraction =
        progress.totalSteps > 0
          ? progress.currentStepIndex / progress.totalSteps
          : 0
      s.updateOperation({
        step: progress.step ?? undefined,
        current: isSingleDoc
          ? progress.currentStepIndex
          : progress.currentDocument + stepFraction,
        total: isSingleDoc ? progress.totalSteps : progress.totalDocuments,
      })
      getCurrentWindow()
        .setProgressBar({
          status: ProgressBarStatus.Normal,
          progress: progress.overallPercent,
        })
        .catch(() => {})
      s.refreshCurrentDocument()
      return
    }

    // Any non-running status ends the run. A failure used to be dropped
    // silently and shown exactly like a success; report its message.
    if (typeof progress.status === 'object' && progress.status !== null) {
      console.error('Processing failed:', progress.status.failed)
    }

    // Set to 100% first, then wait for the CSS transition to finish
    // before removing the bubble.
    s.updateOperation({
      current: s.operation?.total,
      total: s.operation?.total,
    })
    getCurrentWindow()
      .setProgressBar({ status: ProgressBarStatus.Normal, progress: 100 })
      .catch(() => {})
    s.refreshCurrentDocument()
    setTimeout(() => {
      useAppStore.getState().finishOperation()
      getCurrentWindow()
        .setProgressBar({ status: ProgressBarStatus.None, progress: 0 })
        .catch(() => {})
    }, 1000)
  })
}
|
| 737 |
+
|
| 738 |
export const useConfigStore = create<ConfigState>((set) => ({
|
| 739 |
brushConfig: {
|
| 740 |
size: 36,
|