Mayo commited on
Commit
3ab7cd8
·
unverified ·
1 Parent(s): 5f0bc6b

feat: Rust-based processing

Browse files
koharu/src/app.rs CHANGED
@@ -59,6 +59,7 @@ pub struct AppResources {
59
  pub llm: Arc<llm::Model>,
60
  pub renderer: Arc<Renderer>,
61
  pub ml_device: DeviceName,
 
62
  }
63
 
64
  #[derive(Parser)]
@@ -208,6 +209,7 @@ async fn build_resources(cpu: bool, _register_file_assoc: bool) -> Result<AppRes
208
  llm,
209
  renderer,
210
  ml_device,
 
211
  })
212
  }
213
 
 
59
  pub llm: Arc<llm::Model>,
60
  pub renderer: Arc<Renderer>,
61
  pub ml_device: DeviceName,
62
+ pub pipeline: Arc<RwLock<Option<crate::pipeline::PipelineHandle>>>,
63
  }
64
 
65
  #[derive(Parser)]
 
209
  llm,
210
  renderer,
211
  ml_device,
212
+ pipeline: Arc::new(RwLock::new(None)),
213
  })
214
  }
215
 
koharu/src/endpoints.rs CHANGED
@@ -27,7 +27,7 @@ use crate::{
27
  khr::{deserialize_khr, has_khr_magic, serialize_khr},
28
  llm,
29
  result::Result,
30
- state::{Document, TextBlock, TextStyle},
31
  version,
32
  };
33
 
@@ -194,7 +194,7 @@ pub async fn detect(
194
  State(state): State<AppResources>,
195
  Json(payload): Json<IndexPayload>,
196
  ) -> Result<Json<()>> {
197
- let snapshot = {
198
  let guard = state.state.read().await;
199
  guard
200
  .documents
@@ -203,44 +203,14 @@ pub async fn detect(
203
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?
204
  };
205
 
206
- let (text_blocks, segment) = state.ml.detect_dialog(&snapshot.image).await?;
207
- let mut updated = snapshot.clone();
208
- updated.text_blocks = text_blocks;
209
- updated.segment = Some(segment);
210
-
211
- if !updated.text_blocks.is_empty() {
212
- let images: Vec<image::DynamicImage> = updated
213
- .text_blocks
214
- .iter()
215
- .map(|block| {
216
- updated.image.crop_imm(
217
- block.x as u32,
218
- block.y as u32,
219
- block.width as u32,
220
- block.height as u32,
221
- )
222
- })
223
- .collect();
224
-
225
- let font_predictions = state.ml.detect_fonts(&images, 1).await?;
226
- for (block, prediction) in updated.text_blocks.iter_mut().zip(font_predictions) {
227
- let color = prediction.text_color;
228
- let font_size = (prediction.font_size_px > 0.0).then_some(prediction.font_size_px);
229
- block.font_prediction = Some(prediction);
230
- block.style = Some(TextStyle {
231
- font_size,
232
- color: [color[0], color[1], color[2], 255],
233
- ..Default::default()
234
- });
235
- }
236
- }
237
 
238
  let mut guard = state.state.write().await;
239
  let document = guard
240
  .documents
241
  .get_mut(payload.index)
242
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?;
243
- *document = updated;
244
  Ok(Json(()))
245
  }
246
 
@@ -249,7 +219,7 @@ pub async fn ocr(
249
  State(state): State<AppResources>,
250
  Json(payload): Json<IndexPayload>,
251
  ) -> Result<Json<()>> {
252
- let snapshot = {
253
  let guard = state.state.read().await;
254
  guard
255
  .documents
@@ -258,16 +228,14 @@ pub async fn ocr(
258
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?
259
  };
260
 
261
- let text_blocks = state.ml.ocr(&snapshot.image, &snapshot.text_blocks).await?;
262
- let mut updated = snapshot;
263
- updated.text_blocks = text_blocks;
264
 
265
  let mut guard = state.state.write().await;
266
  let document = guard
267
  .documents
268
  .get_mut(payload.index)
269
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?;
270
- *document = updated;
271
  Ok(Json(()))
272
  }
273
 
@@ -276,7 +244,7 @@ pub async fn inpaint(
276
  State(state): State<AppResources>,
277
  Json(payload): Json<IndexPayload>,
278
  ) -> Result<Json<()>> {
279
- let snapshot = {
280
  let guard = state.state.read().await;
281
  guard
282
  .documents
@@ -285,43 +253,14 @@ pub async fn inpaint(
285
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?
286
  };
287
 
288
- let segment = snapshot
289
- .segment
290
- .as_ref()
291
- .ok_or_else(|| anyhow::anyhow!("Segment image not found"))?;
292
- let text_blocks = &snapshot.text_blocks;
293
- let mut segment_data = segment.to_rgba8();
294
- let (seg_width, seg_height) = segment_data.dimensions();
295
-
296
- for y in 0..seg_height {
297
- for x in 0..seg_width {
298
- let pixel = segment_data.get_pixel_mut(x, y);
299
- if pixel.0 != [0, 0, 0, 255] {
300
- let inside_any_block = text_blocks.iter().any(|block| {
301
- x >= block.x as u32
302
- && x < (block.x + block.width) as u32
303
- && y >= block.y as u32
304
- && y < (block.y + block.height) as u32
305
- });
306
- if !inside_any_block {
307
- *pixel = image::Rgba([0, 0, 0, 255]);
308
- }
309
- }
310
- }
311
- }
312
-
313
- let mask = SerializableDynamicImage::from(image::DynamicImage::ImageRgba8(segment_data));
314
- let inpainted = state.ml.inpaint(&snapshot.image, &mask).await?;
315
-
316
- let mut updated = snapshot;
317
- updated.inpainted = Some(inpainted);
318
 
319
  let mut guard = state.state.write().await;
320
  let document = guard
321
  .documents
322
  .get_mut(payload.index)
323
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?;
324
- *document = updated;
325
  Ok(Json(()))
326
  }
327
 
@@ -544,7 +483,7 @@ pub async fn inpaint_partial(
544
  SerializableDynamicImage(snapshot.image.crop_imm(x0, y0, crop_width, crop_height));
545
  let mask_crop = SerializableDynamicImage(mask_image.crop_imm(x0, y0, crop_width, crop_height));
546
 
547
- let inpainted_crop = state.ml.inpaint(&image_crop, &mask_crop).await?;
548
 
549
  let mut stitched = snapshot
550
  .inpainted
@@ -849,3 +788,59 @@ pub async fn download_progress()
849
  });
850
  Sse::new(stream).keep_alive(KeepAlive::default())
851
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  khr::{deserialize_khr, has_khr_magic, serialize_khr},
28
  llm,
29
  result::Result,
30
+ state::{Document, TextBlock},
31
  version,
32
  };
33
 
 
194
  State(state): State<AppResources>,
195
  Json(payload): Json<IndexPayload>,
196
  ) -> Result<Json<()>> {
197
+ let mut snapshot = {
198
  let guard = state.state.read().await;
199
  guard
200
  .documents
 
203
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?
204
  };
205
 
206
+ state.ml.detect(&mut snapshot).await?;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  let mut guard = state.state.write().await;
209
  let document = guard
210
  .documents
211
  .get_mut(payload.index)
212
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?;
213
+ *document = snapshot;
214
  Ok(Json(()))
215
  }
216
 
 
219
  State(state): State<AppResources>,
220
  Json(payload): Json<IndexPayload>,
221
  ) -> Result<Json<()>> {
222
+ let mut snapshot = {
223
  let guard = state.state.read().await;
224
  guard
225
  .documents
 
228
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?
229
  };
230
 
231
+ state.ml.ocr(&mut snapshot).await?;
 
 
232
 
233
  let mut guard = state.state.write().await;
234
  let document = guard
235
  .documents
236
  .get_mut(payload.index)
237
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?;
238
+ *document = snapshot;
239
  Ok(Json(()))
240
  }
241
 
 
244
  State(state): State<AppResources>,
245
  Json(payload): Json<IndexPayload>,
246
  ) -> Result<Json<()>> {
247
+ let mut snapshot = {
248
  let guard = state.state.read().await;
249
  guard
250
  .documents
 
253
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?
254
  };
255
 
256
+ state.ml.inpaint(&mut snapshot).await?;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  let mut guard = state.state.write().await;
259
  let document = guard
260
  .documents
261
  .get_mut(payload.index)
262
  .ok_or_else(|| anyhow::anyhow!("Document not found"))?;
263
+ *document = snapshot;
264
  Ok(Json(()))
265
  }
266
 
 
483
  SerializableDynamicImage(snapshot.image.crop_imm(x0, y0, crop_width, crop_height));
484
  let mask_crop = SerializableDynamicImage(mask_image.crop_imm(x0, y0, crop_width, crop_height));
485
 
486
+ let inpainted_crop = state.ml.inpaint_raw(&image_crop, &mask_crop).await?;
487
 
488
  let mut stitched = snapshot
489
  .inpainted
 
788
  });
789
  Sse::new(stream).keep_alive(KeepAlive::default())
790
  }
791
+
792
+ // --- Auto-processing pipeline endpoints ---
793
+
794
+ /// Start the processing pipeline. Returns immediately; progress is available
795
+ /// via the `process_progress` SSE endpoint.
796
+ pub async fn process(
797
+ State(state): State<AppResources>,
798
+ Json(payload): Json<crate::pipeline::ProcessRequest>,
799
+ ) -> Result<Json<()>> {
800
+ {
801
+ let guard = state.pipeline.read().await;
802
+ if guard.is_some() {
803
+ Err(anyhow::anyhow!("A processing pipeline is already running"))?;
804
+ }
805
+ }
806
+
807
+ let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
808
+ {
809
+ let mut guard = state.pipeline.write().await;
810
+ *guard = Some(crate::pipeline::PipelineHandle {
811
+ cancel: cancel.clone(),
812
+ });
813
+ }
814
+
815
+ let resources = state.clone();
816
+ tokio::spawn(async move {
817
+ crate::pipeline::run_pipeline(resources, payload, cancel).await;
818
+ });
819
+
820
+ Ok(Json(()))
821
+ }
822
+
823
+ pub async fn process_progress()
824
+ -> Sse<impl futures::Stream<Item = std::result::Result<Event, Infallible>>> {
825
+ let rx = crate::pipeline::subscribe();
826
+ let stream = futures::stream::unfold(rx, |mut rx| async {
827
+ loop {
828
+ match rx.recv().await {
829
+ Ok(p) => return Some((Ok(Event::default().json_data(p).unwrap()), rx)),
830
+ Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
831
+ Err(tokio::sync::broadcast::error::RecvError::Closed) => return None,
832
+ }
833
+ }
834
+ });
835
+ Sse::new(stream).keep_alive(KeepAlive::default())
836
+ }
837
+
838
+ pub async fn process_cancel(State(state): State<AppResources>) -> Result<Json<()>> {
839
+ let guard = state.pipeline.read().await;
840
+ if let Some(handle) = guard.as_ref() {
841
+ handle
842
+ .cancel
843
+ .store(true, std::sync::atomic::Ordering::Relaxed);
844
+ }
845
+ Ok(Json(()))
846
+ }
koharu/src/lib.rs CHANGED
@@ -5,6 +5,7 @@ pub mod image;
5
  pub mod khr;
6
  pub mod llm;
7
  pub mod ml;
 
8
  pub mod renderer;
9
  pub mod result;
10
  pub mod server;
 
5
  pub mod khr;
6
  pub mod llm;
7
  pub mod ml;
8
+ pub mod pipeline;
9
  pub mod renderer;
10
  pub mod result;
11
  pub mod server;
koharu/src/ml.rs CHANGED
@@ -1,12 +1,12 @@
1
  use anyhow::Result;
2
- use image::DynamicImage;
3
  use koharu_ml::comic_text_detector::{self, ComicTextDetector};
4
  use koharu_ml::font_detector::{self, FontDetector};
5
  use koharu_ml::lama::{self, Lama};
6
  use koharu_ml::manga_ocr::{self, MangaOcr};
7
 
8
  use crate::image::SerializableDynamicImage;
9
- use crate::state::TextBlock;
10
 
11
  const NEAR_BLACK_THRESHOLD: u8 = 12;
12
  const GRAY_NEAR_BLACK_THRESHOLD: u8 = 60;
@@ -85,11 +85,10 @@ impl Model {
85
  })
86
  }
87
 
88
- pub async fn detect_dialog(
89
- &self,
90
- image: &SerializableDynamicImage,
91
- ) -> Result<(Vec<TextBlock>, SerializableDynamicImage)> {
92
- let (bboxes, segment) = self.dialog_detector.inference(image)?;
93
 
94
  let mut text_blocks: Vec<TextBlock> = bboxes
95
  .into_iter()
@@ -109,22 +108,51 @@ impl Model {
109
  .unwrap_or(std::cmp::Ordering::Equal)
110
  });
111
 
112
- Ok((text_blocks, DynamicImage::ImageLuma8(segment).into()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  }
114
 
115
- pub async fn ocr(
116
- &self,
117
- image: &SerializableDynamicImage,
118
- blocks: &[TextBlock],
119
- ) -> Result<Vec<TextBlock>> {
120
- if blocks.is_empty() {
121
- return Ok(Vec::new());
122
  }
123
 
124
- let crops: Vec<DynamicImage> = blocks
 
125
  .iter()
126
  .map(|block| {
127
- image.crop_imm(
128
  block.x as u32,
129
  block.y as u32,
130
  block.width as u32,
@@ -134,24 +162,54 @@ impl Model {
134
  .collect();
135
  let texts = self.ocr.inference(&crops)?;
136
 
137
- Ok(blocks
138
- .iter()
139
- .cloned()
140
- .zip(texts.into_iter())
141
- .map(|(block, text)| TextBlock {
142
- text: text.into(),
143
- ..block
144
- })
145
- .collect())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  }
147
 
148
- pub async fn inpaint(
 
149
  &self,
150
  image: &SerializableDynamicImage,
151
  mask: &SerializableDynamicImage,
152
  ) -> Result<SerializableDynamicImage> {
153
  let result = self.lama.inference(image, mask)?;
154
-
155
  Ok(result.into())
156
  }
157
 
 
1
  use anyhow::Result;
2
+ use image::{DynamicImage, Rgba};
3
  use koharu_ml::comic_text_detector::{self, ComicTextDetector};
4
  use koharu_ml::font_detector::{self, FontDetector};
5
  use koharu_ml::lama::{self, Lama};
6
  use koharu_ml::manga_ocr::{self, MangaOcr};
7
 
8
  use crate::image::SerializableDynamicImage;
9
+ use crate::state::{Document, TextBlock, TextStyle};
10
 
11
  const NEAR_BLACK_THRESHOLD: u8 = 12;
12
  const GRAY_NEAR_BLACK_THRESHOLD: u8 = 60;
 
85
  })
86
  }
87
 
88
+ /// Detect text blocks and fonts in a document.
89
+ /// Sets `doc.text_blocks` (with font predictions/styles) and `doc.segment`.
90
+ pub async fn detect(&self, doc: &mut Document) -> Result<()> {
91
+ let (bboxes, segment) = self.dialog_detector.inference(&doc.image)?;
 
92
 
93
  let mut text_blocks: Vec<TextBlock> = bboxes
94
  .into_iter()
 
108
  .unwrap_or(std::cmp::Ordering::Equal)
109
  });
110
 
111
+ doc.text_blocks = text_blocks;
112
+ doc.segment = Some(DynamicImage::ImageLuma8(segment).into());
113
+
114
+ if !doc.text_blocks.is_empty() {
115
+ let images: Vec<DynamicImage> = doc
116
+ .text_blocks
117
+ .iter()
118
+ .map(|block| {
119
+ doc.image.crop_imm(
120
+ block.x as u32,
121
+ block.y as u32,
122
+ block.width as u32,
123
+ block.height as u32,
124
+ )
125
+ })
126
+ .collect();
127
+
128
+ let font_predictions = self.detect_fonts(&images, 1).await?;
129
+ for (block, prediction) in doc.text_blocks.iter_mut().zip(font_predictions) {
130
+ let color = prediction.text_color;
131
+ let font_size = (prediction.font_size_px > 0.0).then_some(prediction.font_size_px);
132
+ block.font_prediction = Some(prediction);
133
+ block.style = Some(TextStyle {
134
+ font_size,
135
+ color: [color[0], color[1], color[2], 255],
136
+ ..Default::default()
137
+ });
138
+ }
139
+ }
140
+
141
+ Ok(())
142
  }
143
 
144
+ /// Run OCR on all text blocks in the document.
145
+ /// Updates `doc.text_blocks` with recognized text.
146
+ pub async fn ocr(&self, doc: &mut Document) -> Result<()> {
147
+ if doc.text_blocks.is_empty() {
148
+ return Ok(());
 
 
149
  }
150
 
151
+ let crops: Vec<DynamicImage> = doc
152
+ .text_blocks
153
  .iter()
154
  .map(|block| {
155
+ doc.image.crop_imm(
156
  block.x as u32,
157
  block.y as u32,
158
  block.width as u32,
 
162
  .collect();
163
  let texts = self.ocr.inference(&crops)?;
164
 
165
+ for (block, text) in doc.text_blocks.iter_mut().zip(texts) {
166
+ block.text = text.into();
167
+ }
168
+
169
+ Ok(())
170
+ }
171
+
172
+ /// Inpaint text regions in the document.
173
+ /// Builds mask from `doc.segment` + `doc.text_blocks`, sets `doc.inpainted`.
174
+ pub async fn inpaint(&self, doc: &mut Document) -> Result<()> {
175
+ let segment = doc
176
+ .segment
177
+ .as_ref()
178
+ .ok_or_else(|| anyhow::anyhow!("Segment image not found"))?;
179
+ let mut segment_data = segment.to_rgba8();
180
+ let (seg_width, seg_height) = segment_data.dimensions();
181
+
182
+ for y in 0..seg_height {
183
+ for x in 0..seg_width {
184
+ let pixel = segment_data.get_pixel_mut(x, y);
185
+ if pixel.0 != [0, 0, 0, 255] {
186
+ let inside_any_block = doc.text_blocks.iter().any(|block| {
187
+ x >= block.x as u32
188
+ && x < (block.x + block.width) as u32
189
+ && y >= block.y as u32
190
+ && y < (block.y + block.height) as u32
191
+ });
192
+ if !inside_any_block {
193
+ *pixel = Rgba([0, 0, 0, 255]);
194
+ }
195
+ }
196
+ }
197
+ }
198
+
199
+ let mask = SerializableDynamicImage::from(DynamicImage::ImageRgba8(segment_data));
200
+ let result = self.lama.inference(&doc.image, &mask)?;
201
+ doc.inpainted = Some(result.into());
202
+
203
+ Ok(())
204
  }
205
 
206
+ /// Low-level inpaint: inpaint a specific image region with a mask.
207
+ pub async fn inpaint_raw(
208
  &self,
209
  image: &SerializableDynamicImage,
210
  mask: &SerializableDynamicImage,
211
  ) -> Result<SerializableDynamicImage> {
212
  let result = self.lama.inference(image, mask)?;
 
213
  Ok(result.into())
214
  }
215
 
koharu/src/pipeline.rs ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::sync::{
2
+ Arc,
3
+ atomic::{AtomicBool, Ordering},
4
+ };
5
+ use std::time::Duration;
6
+
7
+ use koharu_ml::llm::ModelId;
8
+ use koharu_renderer::renderer::TextShaderEffect;
9
+ use once_cell::sync::Lazy;
10
+ use serde::{Deserialize, Serialize};
11
+ use std::str::FromStr;
12
+ use tokio::sync::broadcast;
13
+
14
+ use crate::app::AppResources;
15
+
16
+ /// Steps in the processing pipeline.
17
+ #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
18
+ #[serde(rename_all = "camelCase")]
19
+ pub enum PipelineStep {
20
+ Detect,
21
+ Ocr,
22
+ Inpaint,
23
+ LlmGenerate,
24
+ Render,
25
+ }
26
+
27
+ impl PipelineStep {
28
+ pub const ALL: &[PipelineStep] = &[
29
+ PipelineStep::Detect,
30
+ PipelineStep::Ocr,
31
+ PipelineStep::Inpaint,
32
+ PipelineStep::LlmGenerate,
33
+ PipelineStep::Render,
34
+ ];
35
+ }
36
+
37
+ /// Status of the pipeline.
38
+ #[derive(Debug, Clone, Serialize, PartialEq)]
39
+ #[serde(rename_all = "camelCase")]
40
+ pub enum PipelineStatus {
41
+ Running,
42
+ Completed,
43
+ Cancelled,
44
+ Failed(String),
45
+ }
46
+
47
+ /// SSE event payload sent to the frontend.
48
+ #[derive(Debug, Clone, Serialize)]
49
+ #[serde(rename_all = "camelCase")]
50
+ pub struct PipelineProgress {
51
+ pub status: PipelineStatus,
52
+ pub step: Option<PipelineStep>,
53
+ pub current_document: usize,
54
+ pub total_documents: usize,
55
+ pub current_step_index: usize,
56
+ pub total_steps: usize,
57
+ pub overall_percent: u8,
58
+ }
59
+
60
+ /// Request payload for the auto-process endpoint.
61
+ #[derive(Debug, Deserialize)]
62
+ #[serde(rename_all = "camelCase")]
63
+ pub struct ProcessRequest {
64
+ /// None means all documents; Some(i) means single document.
65
+ pub index: Option<usize>,
66
+ /// LLM model id to use (will load if not already loaded).
67
+ pub llm_model_id: Option<String>,
68
+ /// Target language for translation.
69
+ pub language: Option<String>,
70
+ /// Shader effect for rendering.
71
+ pub shader_effect: Option<TextShaderEffect>,
72
+ }
73
+
74
+ /// Handle to a running pipeline, used for cancellation.
75
+ pub struct PipelineHandle {
76
+ pub cancel: Arc<AtomicBool>,
77
+ }
78
+
79
+ // Global broadcast channel for pipeline progress (mirrors download.rs pattern).
80
+ static PIPELINE_TX: Lazy<broadcast::Sender<PipelineProgress>> =
81
+ Lazy::new(|| broadcast::channel(256).0);
82
+
83
+ pub fn subscribe() -> broadcast::Receiver<PipelineProgress> {
84
+ PIPELINE_TX.subscribe()
85
+ }
86
+
87
+ fn emit(progress: PipelineProgress) {
88
+ let _ = PIPELINE_TX.send(progress);
89
+ }
90
+
91
+ fn compute_percent(doc: usize, step: usize, total_docs: usize, total_steps: usize) -> u8 {
92
+ let done_units = doc * total_steps + step;
93
+ let total_units = total_docs * total_steps;
94
+ if total_units == 0 {
95
+ return 0;
96
+ }
97
+ ((done_units as f64 / total_units as f64) * 100.0).round() as u8
98
+ }
99
+
100
+ /// Run the processing pipeline. Called from a spawned task.
101
+ pub async fn run_pipeline(
102
+ resources: AppResources,
103
+ request: ProcessRequest,
104
+ cancel: Arc<AtomicBool>,
105
+ ) {
106
+ let result = run_pipeline_inner(&resources, &request, &cancel).await;
107
+
108
+ let total_docs = match &request.index {
109
+ Some(_) => 1,
110
+ None => resources.state.read().await.documents.len(),
111
+ };
112
+
113
+ match result {
114
+ Ok(()) if cancel.load(Ordering::Relaxed) => {
115
+ emit(PipelineProgress {
116
+ status: PipelineStatus::Cancelled,
117
+ step: None,
118
+ current_document: total_docs,
119
+ total_documents: total_docs,
120
+ current_step_index: 0,
121
+ total_steps: PipelineStep::ALL.len(),
122
+ overall_percent: 0,
123
+ });
124
+ }
125
+ Ok(()) => {
126
+ emit(PipelineProgress {
127
+ status: PipelineStatus::Completed,
128
+ step: None,
129
+ current_document: total_docs,
130
+ total_documents: total_docs,
131
+ current_step_index: PipelineStep::ALL.len(),
132
+ total_steps: PipelineStep::ALL.len(),
133
+ overall_percent: 100,
134
+ });
135
+ }
136
+ Err(e) => {
137
+ tracing::error!("Pipeline failed: {e:#}");
138
+ emit(PipelineProgress {
139
+ status: PipelineStatus::Failed(e.to_string()),
140
+ step: None,
141
+ current_document: 0,
142
+ total_documents: total_docs,
143
+ current_step_index: 0,
144
+ total_steps: PipelineStep::ALL.len(),
145
+ overall_percent: 0,
146
+ });
147
+ }
148
+ }
149
+
150
+ // Clear the pipeline handle
151
+ let mut guard = resources.pipeline.write().await;
152
+ *guard = None;
153
+ }
154
+
155
+ async fn run_pipeline_inner(
156
+ res: &AppResources,
157
+ req: &ProcessRequest,
158
+ cancel: &Arc<AtomicBool>,
159
+ ) -> anyhow::Result<()> {
160
+ let total_docs = {
161
+ let guard = res.state.read().await;
162
+ let len = guard.documents.len();
163
+ match req.index {
164
+ Some(i) if i >= len => {
165
+ anyhow::bail!("Document index {} out of range (have {})", i, len);
166
+ }
167
+ Some(_) => 1,
168
+ None => len,
169
+ }
170
+ };
171
+
172
+ if total_docs == 0 {
173
+ return Ok(());
174
+ }
175
+
176
+ // Ensure LLM is loaded
177
+ if let Some(model_id) = &req.llm_model_id {
178
+ if !res.llm.ready().await {
179
+ let id = ModelId::from_str(model_id)?;
180
+ res.llm.load(id).await;
181
+ // Poll until ready (with timeout)
182
+ for _ in 0..300 {
183
+ if res.llm.ready().await {
184
+ break;
185
+ }
186
+ tokio::time::sleep(Duration::from_millis(100)).await;
187
+ if cancel.load(Ordering::Relaxed) {
188
+ return Ok(());
189
+ }
190
+ }
191
+ if !res.llm.ready().await {
192
+ anyhow::bail!("LLM failed to load within timeout");
193
+ }
194
+ }
195
+ }
196
+
197
+ if let Some(locale) = req.language.as_ref() {
198
+ koharu_ml::set_locale(locale.clone());
199
+ }
200
+
201
+ let start_index = req.index.unwrap_or(0);
202
+ let end_index = req.index.map(|i| i + 1).unwrap_or(total_docs);
203
+ let total_steps = PipelineStep::ALL.len();
204
+
205
+ for (doc_ordinal, doc_index) in (start_index..end_index).enumerate() {
206
+ for (step_ordinal, step) in PipelineStep::ALL.iter().enumerate() {
207
+ if cancel.load(Ordering::Relaxed) {
208
+ return Ok(());
209
+ }
210
+
211
+ let overall = compute_percent(doc_ordinal, step_ordinal, total_docs, total_steps);
212
+ emit(PipelineProgress {
213
+ status: PipelineStatus::Running,
214
+ step: Some(*step),
215
+ current_document: doc_ordinal,
216
+ total_documents: total_docs,
217
+ current_step_index: step_ordinal,
218
+ total_steps,
219
+ overall_percent: overall,
220
+ });
221
+
222
+ // Give the runtime a chance to flush the SSE event before blocking ML work
223
+ tokio::task::yield_now().await;
224
+ tokio::time::sleep(Duration::from_millis(1)).await;
225
+
226
+ // Snapshot → process → write back
227
+ let mut snapshot =
228
+ {
229
+ let guard = res.state.read().await;
230
+ guard.documents.get(doc_index).cloned().ok_or_else(|| {
231
+ anyhow::anyhow!("Document not found at index {}", doc_index)
232
+ })?
233
+ };
234
+
235
+ match step {
236
+ PipelineStep::Detect => res.ml.detect(&mut snapshot).await?,
237
+ PipelineStep::Ocr => res.ml.ocr(&mut snapshot).await?,
238
+ PipelineStep::Inpaint => res.ml.inpaint(&mut snapshot).await?,
239
+ PipelineStep::LlmGenerate => {
240
+ res.llm.generate(&mut snapshot).await?;
241
+ }
242
+ PipelineStep::Render => {
243
+ res.renderer.render(
244
+ &mut snapshot,
245
+ None,
246
+ req.shader_effect.unwrap_or_default(),
247
+ )?;
248
+ }
249
+ }
250
+
251
+ let mut guard = res.state.write().await;
252
+ let document = guard
253
+ .documents
254
+ .get_mut(doc_index)
255
+ .ok_or_else(|| anyhow::anyhow!("Document not found at index {}", doc_index))?;
256
+ *document = snapshot;
257
+ }
258
+ }
259
+
260
+ Ok(())
261
+ }
koharu/src/server.rs CHANGED
@@ -76,6 +76,9 @@ fn build_router(state: AppResources) -> Router {
76
  .route("/api/llm_ready", get(llm_ready).post(llm_ready))
77
  .route("/api/llm_generate", post(llm_generate))
78
  .route("/api/download_progress", get(download_progress))
 
 
 
79
  .with_state(state)
80
  .layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
81
  .layer(
 
76
  .route("/api/llm_ready", get(llm_ready).post(llm_ready))
77
  .route("/api/llm_generate", post(llm_generate))
78
  .route("/api/download_progress", get(download_progress))
79
+ .route("/api/process", post(process))
80
+ .route("/api/process_cancel", post(process_cancel))
81
+ .route("/api/process_progress", get(process_progress))
82
  .with_state(state)
83
  .layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
84
  .layer(
scripts/dev.ts CHANGED
@@ -58,7 +58,7 @@ async function setupCuda() {
58
  }
59
 
60
  throw new Error(
61
- 'NVCC not found. Please install the CUDA Toolkit from https://developer.nvidia.com/cuda-12-9-1-download-archive',
62
  )
63
  }
64
 
 
58
  }
59
 
60
  throw new Error(
61
+ 'NVCC not found. Please install the CUDA Toolkit from https://developer.nvidia.com/cuda-downloads',
62
  )
63
  }
64
 
ui/components/ActivityBubble.tsx CHANGED
@@ -28,7 +28,7 @@ function ProgressBar({ percent }: { percent?: number }) {
28
  <div className='bg-muted relative h-1.5 flex-1 overflow-hidden rounded-full'>
29
  {typeof percent === 'number' ? (
30
  <div
31
- className='bg-primary h-full rounded-full transition-[width] duration-300'
32
  style={{ width: `${percent}%` }}
33
  />
34
  ) : (
 
28
  <div className='bg-muted relative h-1.5 flex-1 overflow-hidden rounded-full'>
29
  {typeof percent === 'number' ? (
30
  <div
31
+ className='bg-primary h-full rounded-full transition-[width] duration-700 ease-out'
32
  style={{ width: `${percent}%` }}
33
  />
34
  ) : (
ui/lib/backend.ts CHANGED
@@ -231,38 +231,43 @@ export type DownloadProgress = {
231
  status: 'Started' | 'Downloading' | 'Completed' | { Failed: string }
232
  }
233
 
234
- export function subscribeDownloadProgress(
235
- callback: (progress: DownloadProgress) => void,
 
 
 
 
 
 
 
 
 
 
 
236
  ): () => void {
237
- let stopped = false
238
 
239
  ;(async () => {
240
  await ensureInitialized()
241
- const es = new EventSource(`${apiBase}/download_progress`)
242
  es.onmessage = (event) => {
243
- if (stopped) {
244
- es.close()
245
- return
246
- }
247
  try {
248
- callback(JSON.parse(event.data) as DownloadProgress)
249
  } catch (_) {}
250
  }
251
- const check = () => {
252
- if (stopped) es.close()
253
- }
254
- const id = setInterval(check, 1000)
255
- es.onerror = () => {
256
- clearInterval(id)
257
- es.close()
258
- }
259
  })()
260
 
261
  return () => {
262
- stopped = true
263
  }
264
  }
265
 
 
 
 
 
 
 
266
  export const isTauri = isTauriEnv
267
 
268
  export const isMacOS = (): boolean => {
 
231
  status: 'Started' | 'Downloading' | 'Completed' | { Failed: string }
232
  }
233
 
234
+ export type ProcessProgress = {
235
+ status: 'running' | 'completed' | 'cancelled' | { failed: string }
236
+ step: string | null
237
+ currentDocument: number
238
+ totalDocuments: number
239
+ currentStepIndex: number
240
+ totalSteps: number
241
+ overallPercent: number
242
+ }
243
+
244
+ function subscribeSSE<T>(
245
+ endpoint: string,
246
+ callback: (data: T) => void,
247
  ): () => void {
248
+ let es: EventSource | null = null
249
 
250
  ;(async () => {
251
  await ensureInitialized()
252
+ es = new EventSource(`${apiBase}/${endpoint}`)
253
  es.onmessage = (event) => {
 
 
 
 
254
  try {
255
+ callback(JSON.parse(event.data) as T)
256
  } catch (_) {}
257
  }
 
 
 
 
 
 
 
 
258
  })()
259
 
260
  return () => {
261
+ es?.close()
262
  }
263
  }
264
 
265
+ export const subscribeDownloadProgress = (cb: (p: DownloadProgress) => void) =>
266
+ subscribeSSE<DownloadProgress>('download_progress', cb)
267
+
268
+ export const subscribeProcessProgress = (cb: (p: ProcessProgress) => void) =>
269
+ subscribeSSE<ProcessProgress>('process_progress', cb)
270
+
271
  export const isTauri = isTauriEnv
272
 
273
  export const isMacOS = (): boolean => {
ui/lib/operations.ts CHANGED
@@ -41,10 +41,15 @@ export const createOperationSlice = (set: any): OperationSlice => ({
41
  : { operation: undefined },
42
  ),
43
  finishOperation: () => set({ operation: undefined }),
44
- cancelOperation: () =>
45
  set((state: OperationSlice) =>
46
  state.operation
47
  ? { operation: { ...state.operation, cancelRequested: true } }
48
  : { operation: undefined },
49
- ),
 
 
 
 
 
50
  })
 
41
  : { operation: undefined },
42
  ),
43
  finishOperation: () => set({ operation: undefined }),
44
+ cancelOperation: () => {
45
  set((state: OperationSlice) =>
46
  state.operation
47
  ? { operation: { ...state.operation, cancelRequested: true } }
48
  : { operation: undefined },
49
+ )
50
+ // Also cancel backend pipeline if running
51
+ import('@/lib/backend').then(({ invoke }) => {
52
+ invoke('process_cancel').catch(() => {})
53
+ })
54
+ },
55
  })
ui/lib/store.ts CHANGED
@@ -1,7 +1,13 @@
1
  'use client'
2
 
3
  import { create } from 'zustand'
4
- import { invoke, getCurrentWindow, ProgressBarStatus } from '@/lib/backend'
 
 
 
 
 
 
5
  import {
6
  Document,
7
  InpaintRegion,
@@ -15,18 +21,6 @@ type LlmModelInfo = {
15
  languages: string[]
16
  }
17
 
18
- type ProcessAction = 'detect' | 'ocr' | 'inpaint' | 'llmGenerate' | 'render'
19
-
20
- type ProcessImageOptionsObject = {
21
- onProgress?: (progress: number) => Promise<void>
22
- onStepChange?: (step: ProcessAction) => Promise<void> | void
23
- skipOperationTracking?: boolean
24
- }
25
-
26
- type ProcessImageOptions =
27
- | ((progress: number) => Promise<void>)
28
- | ProcessImageOptionsObject
29
-
30
  const createTextBlockSyncer = () => {
31
  let pending: {
32
  index: number
@@ -211,11 +205,7 @@ type AppState = OperationSlice & {
211
  index?: number,
212
  textBlockIndex?: number,
213
  ) => Promise<void>
214
- processImage: (
215
- _?: any,
216
- index?: number,
217
- options?: ProcessImageOptions,
218
- ) => Promise<void>
219
  inpaintAndRenderImage: (_?: any, index?: number) => Promise<void>
220
  processAllImages: () => Promise<void>
221
  exportDocument: () => Promise<void>
@@ -619,98 +609,27 @@ export const useAppStore = create<AppState>((set, get) => {
619
  void get().renderTextBlock(undefined, index, textBlockIndex)
620
  }
621
  },
622
- // batch proceeses
623
- processImage: async (_, index, options) => {
624
- const normalizedOptions: ProcessImageOptionsObject =
625
- typeof options === 'function'
626
- ? { onProgress: options, skipOperationTracking: false }
627
- : (options ?? {})
628
-
629
- const { onProgress, onStepChange, skipOperationTracking } =
630
- normalizedOptions
631
- const operation = get().operation
632
- const isBatchRun = operation?.type === 'process-all'
633
-
634
- if (!get().llmReady) {
635
- await get().llmList()
636
- await get().llmToggleLoadUnload()
637
- }
638
-
639
  index = index ?? get().currentDocumentIndex
640
- console.log('Processing image at index', index)
641
- const setProgres = onProgress ?? get().setProgress
642
- const shouldTrackOperation = skipOperationTracking !== true && !isBatchRun
643
- const ownsOperation = shouldTrackOperation && !isBatchRun
644
-
645
- const actions: ProcessAction[] = [
646
- 'detect',
647
- 'ocr',
648
- 'inpaint',
649
- 'llmGenerate',
650
- 'render',
651
- ]
652
- const totalSteps = actions.length
653
-
654
- if (shouldTrackOperation) {
655
- const firstStep = actions[0] ?? 'detect'
656
- if (ownsOperation) {
657
- get().startOperation({
658
- type: 'process-current',
659
- step: firstStep,
660
- current: 0,
661
- total: totalSteps,
662
- cancellable: true,
663
- })
664
- } else {
665
- get().updateOperation({
666
- step: firstStep,
667
- current: 0,
668
- total: totalSteps,
669
- })
670
- }
671
- }
672
-
673
- await setProgres(0)
674
- for (let i = 0; i < actions.length; i++) {
675
- if (get().operation?.cancelRequested) {
676
- break
677
- }
678
-
679
- const action = actions[i]
680
-
681
- if (onStepChange) {
682
- await onStepChange(action)
683
- }
684
-
685
- if (shouldTrackOperation) {
686
- get().updateOperation({
687
- step: action,
688
- current: i,
689
- total: totalSteps,
690
- })
691
- }
692
-
693
- await (get() as any)[actions[i]](_, index)
694
- await setProgres(Math.floor(((i + 1) / totalSteps) * 100))
695
- }
696
-
697
- const cancelled = get().operation?.cancelRequested
698
-
699
- if (shouldTrackOperation && ownsOperation && !cancelled) {
700
- get().updateOperation({ current: totalSteps, total: totalSteps })
701
- }
702
-
703
- if (shouldTrackOperation && ownsOperation) {
704
  get().finishOperation()
705
- }
706
-
707
- if (!onProgress) {
708
  await get().clearProgress()
709
  }
710
-
711
- if (cancelled) {
712
- return
713
- }
714
  },
715
 
716
  inpaintAndRenderImage: async (_, index) => {
@@ -723,59 +642,23 @@ export const useAppStore = create<AppState>((set, get) => {
723
  const total = get().totalPages
724
  if (!total) return
725
 
726
- if (!get().llmReady) {
727
- await get().llmList()
728
- await get().llmToggleLoadUnload()
729
- }
730
-
731
  get().startOperation({
732
  type: 'process-all',
733
  cancellable: true,
734
  current: 0,
735
  total,
736
  })
737
-
738
- for (let index = 0; index < total; index++) {
739
- if (get().operation?.cancelRequested) break
740
-
741
- // Switch to this document
742
- await get().setCurrentDocumentIndex(index)
743
- get().updateOperation({
744
- current: index,
745
- total,
746
- })
747
-
748
- await get().processImage(null, index, {
749
- onProgress: async (progress) => {
750
- if (get().operation?.cancelRequested) return
751
- const currentValue = index + progress / 100
752
- const overall = Math.min(
753
- 100,
754
- Math.round((currentValue / total) * 100),
755
- )
756
- await get().setProgress(overall)
757
- get().updateOperation({ current: currentValue, total })
758
- },
759
- onStepChange: (step) => {
760
- if (get().operation?.cancelRequested) return
761
- get().updateOperation({ step })
762
- },
763
- skipOperationTracking: true,
764
  })
765
-
766
- if (get().operation?.cancelRequested) {
767
- break
768
- }
769
-
770
- get().updateOperation({ current: index + 1, total })
771
- }
772
-
773
- if (!get().operation?.cancelRequested) {
774
- get().updateOperation({ current: total, total })
775
  }
776
-
777
- await get().clearProgress()
778
- get().finishOperation()
779
  },
780
 
781
  exportDocument: async () => {
@@ -807,6 +690,51 @@ type ConfigState = {
807
  setBrushConfig: (config: Partial<ConfigState['brushConfig']>) => void
808
  }
809
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
810
  export const useConfigStore = create<ConfigState>((set) => ({
811
  brushConfig: {
812
  size: 36,
 
1
  'use client'
2
 
3
  import { create } from 'zustand'
4
+ import {
5
+ invoke,
6
+ subscribeProcessProgress,
7
+ getCurrentWindow,
8
+ ProgressBarStatus,
9
+ type ProcessProgress,
10
+ } from '@/lib/backend'
11
  import {
12
  Document,
13
  InpaintRegion,
 
21
  languages: string[]
22
  }
23
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  const createTextBlockSyncer = () => {
25
  let pending: {
26
  index: number
 
205
  index?: number,
206
  textBlockIndex?: number,
207
  ) => Promise<void>
208
+ processImage: (_?: any, index?: number) => Promise<void>
 
 
 
 
209
  inpaintAndRenderImage: (_?: any, index?: number) => Promise<void>
210
  processAllImages: () => Promise<void>
211
  exportDocument: () => Promise<void>
 
609
  void get().renderTextBlock(undefined, index, textBlockIndex)
610
  }
611
  },
612
+ // Auto-processing: delegates to the backend pipeline; progress via SSE
613
+ processImage: async (_, index) => {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
614
  index = index ?? get().currentDocumentIndex
615
+ get().startOperation({
616
+ type: 'process-current',
617
+ cancellable: true,
618
+ current: 0,
619
+ total: 5,
620
+ })
621
+ try {
622
+ await invoke('process', {
623
+ index,
624
+ llmModelId: get().llmSelectedModel,
625
+ language: get().llmSelectedLanguage,
626
+ shaderEffect: get().renderEffect,
627
+ })
628
+ } catch (err) {
629
+ console.error('Failed to start processing:', err)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  get().finishOperation()
 
 
 
631
  await get().clearProgress()
632
  }
 
 
 
 
633
  },
634
 
635
  inpaintAndRenderImage: async (_, index) => {
 
642
  const total = get().totalPages
643
  if (!total) return
644
 
 
 
 
 
 
645
  get().startOperation({
646
  type: 'process-all',
647
  cancellable: true,
648
  current: 0,
649
  total,
650
  })
651
+ try {
652
+ await invoke('process', {
653
+ llmModelId: get().llmSelectedModel,
654
+ language: get().llmSelectedLanguage,
655
+ shaderEffect: get().renderEffect,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  })
657
+ } catch (err) {
658
+ console.error('Failed to start processing:', err)
659
+ get().finishOperation()
660
+ await get().clearProgress()
 
 
 
 
 
 
661
  }
 
 
 
662
  },
663
 
664
  exportDocument: async () => {
 
690
  setBrushConfig: (config: Partial<ConfigState['brushConfig']>) => void
691
  }
692
 
693
+ // Subscribe to pipeline progress at module level so the EventSource is
694
+ // connected before any pipeline is started (avoids race with lazy component mount).
695
+ if (typeof window !== 'undefined') {
696
+ subscribeProcessProgress((progress: ProcessProgress) => {
697
+ const s = useAppStore.getState()
698
+ if (progress.status === 'running') {
699
+ const isSingleDoc = progress.totalDocuments <= 1
700
+ s.updateOperation({
701
+ step: progress.step ?? undefined,
702
+ current: isSingleDoc
703
+ ? progress.currentStepIndex
704
+ : progress.currentDocument +
705
+ (progress.totalSteps > 0
706
+ ? progress.currentStepIndex / progress.totalSteps
707
+ : 0),
708
+ total: isSingleDoc ? progress.totalSteps : progress.totalDocuments,
709
+ })
710
+ getCurrentWindow()
711
+ .setProgressBar({
712
+ status: ProgressBarStatus.Normal,
713
+ progress: progress.overallPercent,
714
+ })
715
+ .catch(() => {})
716
+ s.refreshCurrentDocument()
717
+ } else {
718
+ // Set to 100% first, then wait for the CSS transition to finish
719
+ // before removing the bubble.
720
+ s.updateOperation({
721
+ current: s.operation?.total,
722
+ total: s.operation?.total,
723
+ })
724
+ getCurrentWindow()
725
+ .setProgressBar({ status: ProgressBarStatus.Normal, progress: 100 })
726
+ .catch(() => {})
727
+ s.refreshCurrentDocument()
728
+ setTimeout(() => {
729
+ useAppStore.getState().finishOperation()
730
+ getCurrentWindow()
731
+ .setProgressBar({ status: ProgressBarStatus.None, progress: 0 })
732
+ .catch(() => {})
733
+ }, 1000)
734
+ }
735
+ })
736
+ }
737
+
738
  export const useConfigStore = create<ConfigState>((set) => ({
739
  brushConfig: {
740
  size: 36,