Mayo committed on
Commit
8ebd331
·
1 Parent(s): 296cce1

perf: improve ctd

Browse files
koharu-app/src/llm.rs CHANGED
@@ -26,17 +26,19 @@ use crate::config as app_config;
26
 
27
  pub use koharu_llm::prefetch;
28
 
29
- #[derive(Debug, Clone, Copy, PartialEq, Eq)]
30
- struct BlockStartTag {
31
- offset: usize,
32
- len: usize,
33
- id: usize,
34
- }
35
-
36
- #[derive(Debug, Clone, Copy, PartialEq, Eq)]
37
- struct BlockEndTag {
38
- offset: usize,
39
- len: usize,
 
 
40
  }
41
 
42
  #[allow(clippy::large_enum_variant)]
@@ -125,18 +127,6 @@ fn state_target(state: &State) -> Option<LlmTarget> {
125
  }
126
  }
127
 
128
- fn escape_block_text(text: &str) -> String {
129
- text.replace('&', "&amp;")
130
- .replace('<', "&lt;")
131
- .replace('>', "&gt;")
132
- }
133
-
134
- fn unescape_block_text(text: &str) -> String {
135
- text.replace("&lt;", "<")
136
- .replace("&gt;", ">")
137
- .replace("&amp;", "&")
138
- }
139
-
140
  fn strip_wrapping_quotes(text: &str) -> String {
141
  let mut current = text.trim();
142
 
@@ -194,22 +184,29 @@ fn format_document_blocks(blocks: &[TextBlock]) -> String {
194
  .enumerate()
195
  .map(|(idx, block)| {
196
  let text = block.text.as_deref().unwrap_or("<empty>");
197
- format!(
198
- r#"<block id="{idx}">
199
- {}
200
- </block>"#,
201
- escape_block_text(text)
202
- )
203
  })
204
  .collect::<Vec<_>>()
205
  .join("\n")
206
  }
207
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  fn parse_tagged_blocks(
209
  translation: &str,
210
  expected_blocks: usize,
211
  ) -> anyhow::Result<Option<Vec<String>>> {
212
- if find_next_block_start_tag(translation).is_none() {
213
  return Ok(None);
214
  }
215
 
@@ -219,42 +216,29 @@ fn parse_tagged_blocks(
219
  let mut parsed_count = 0usize;
220
  let mut ignored_count = 0usize;
221
 
222
- while let Some(start_tag) = find_next_block_start_tag(cursor) {
223
  found_any = true;
224
- cursor = &cursor[start_tag.offset + start_tag.len..];
 
 
 
 
 
 
225
 
226
- let id = start_tag.id;
227
  if id >= expected_blocks {
228
  ignored_count += 1;
229
  tracing::warn!("Ignoring translated block id {id} for {expected_blocks} source blocks");
230
- let closing_tag = find_next_block_end_tag(cursor);
231
- let boundary = block_boundary(cursor, closing_tag.map(|tag| tag.offset));
232
- cursor = if closing_tag.map(|tag| tag.offset) == Some(boundary) {
233
- let closing_len = closing_tag.map(|tag| tag.len).unwrap_or(0);
234
- &cursor[boundary + closing_len..]
235
- } else {
236
- &cursor[boundary..]
237
- };
238
- continue;
239
- }
240
-
241
- let closing_tag = find_next_block_end_tag(cursor);
242
- let block_end = block_boundary(cursor, closing_tag.map(|tag| tag.offset));
243
- let content = unescape_block_text(cursor[..block_end].trim());
244
-
245
- if blocks[id].is_empty() {
246
- parsed_count += 1;
247
  } else {
248
- tracing::warn!("Translated block id {id} appeared more than once, keeping latest");
 
 
 
 
 
249
  }
250
- blocks[id] = content;
251
 
252
- cursor = if closing_tag.map(|tag| tag.offset) == Some(block_end) {
253
- let closing_len = closing_tag.map(|tag| tag.len).unwrap_or(0);
254
- &cursor[block_end + closing_len..]
255
- } else {
256
- &cursor[block_end..]
257
- };
258
  }
259
 
260
  if !found_any {
@@ -297,151 +281,6 @@ fn split_legacy_lines(translation: &str, expected_blocks: usize) -> anyhow::Resu
297
  Ok(translations)
298
  }
299
 
300
- fn block_boundary(cursor: &str, closing_tag: Option<usize>) -> usize {
301
- let next_block_start = find_next_block_start_tag(cursor).map(|tag| tag.offset);
302
- match (closing_tag, next_block_start) {
303
- (Some(close), Some(next)) => close.min(next),
304
- (Some(close), None) => close,
305
- (None, Some(next)) => next,
306
- (None, None) => cursor.len(),
307
- }
308
- }
309
-
310
- fn find_next_block_start_tag(text: &str) -> Option<BlockStartTag> {
311
- let mut search_from = 0usize;
312
- while let Some(rel_start) = text[search_from..].find('<') {
313
- let offset = search_from + rel_start;
314
- if let Some((len, id)) = parse_block_start_tag(&text[offset..]) {
315
- return Some(BlockStartTag { offset, len, id });
316
- }
317
- search_from = offset + 1;
318
- }
319
- None
320
- }
321
-
322
- fn parse_block_start_tag(text: &str) -> Option<(usize, usize)> {
323
- let bytes = text.as_bytes();
324
- if bytes.first().copied()? != b'<' {
325
- return None;
326
- }
327
-
328
- let mut index = 1usize;
329
- skip_ascii_whitespace(bytes, &mut index);
330
- if !consume_ascii_keyword(bytes, &mut index, "block") {
331
- return None;
332
- }
333
-
334
- let mut parsed_id = None;
335
- loop {
336
- skip_ascii_whitespace(bytes, &mut index);
337
- match bytes.get(index).copied()? {
338
- b'>' => return parsed_id.map(|id| (index + 1, id)),
339
- b'/' if bytes.get(index + 1).copied() == Some(b'>') => {
340
- return parsed_id.map(|id| (index + 2, id));
341
- }
342
- _ => {}
343
- }
344
-
345
- let name_start = index;
346
- while matches!(
347
- bytes.get(index).copied(),
348
- Some(b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-')
349
- ) {
350
- index += 1;
351
- }
352
- if index == name_start {
353
- return None;
354
- }
355
- let attr_name = &text[name_start..index];
356
-
357
- skip_ascii_whitespace(bytes, &mut index);
358
- if bytes.get(index).copied()? != b'=' {
359
- return None;
360
- }
361
- index += 1;
362
- skip_ascii_whitespace(bytes, &mut index);
363
-
364
- let attr_value = match bytes.get(index).copied()? {
365
- b'"' | b'\'' => {
366
- let quote = bytes[index];
367
- index += 1;
368
- let value_start = index;
369
- while bytes.get(index).copied()? != quote {
370
- index += 1;
371
- }
372
- let value = &text[value_start..index];
373
- index += 1;
374
- value
375
- }
376
- _ => {
377
- let value_start = index;
378
- while matches!(bytes.get(index).copied(), Some(byte) if !byte.is_ascii_whitespace() && byte != b'>')
379
- {
380
- index += 1;
381
- }
382
- &text[value_start..index]
383
- }
384
- };
385
-
386
- if attr_name.eq_ignore_ascii_case("id") {
387
- parsed_id = attr_value.parse::<usize>().ok();
388
- }
389
- }
390
- }
391
-
392
- fn find_next_block_end_tag(text: &str) -> Option<BlockEndTag> {
393
- let mut search_from = 0usize;
394
- while let Some(rel_start) = text[search_from..].find('<') {
395
- let offset = search_from + rel_start;
396
- if let Some(len) = parse_block_end_tag(&text[offset..]) {
397
- return Some(BlockEndTag { offset, len });
398
- }
399
- search_from = offset + 1;
400
- }
401
- None
402
- }
403
-
404
- fn parse_block_end_tag(text: &str) -> Option<usize> {
405
- let bytes = text.as_bytes();
406
- if bytes.first().copied()? != b'<' {
407
- return None;
408
- }
409
-
410
- let mut index = 1usize;
411
- skip_ascii_whitespace(bytes, &mut index);
412
- if bytes.get(index).copied()? != b'/' {
413
- return None;
414
- }
415
- index += 1;
416
- skip_ascii_whitespace(bytes, &mut index);
417
- if !consume_ascii_keyword(bytes, &mut index, "block") {
418
- return None;
419
- }
420
- skip_ascii_whitespace(bytes, &mut index);
421
- if bytes.get(index).copied()? != b'>' {
422
- return None;
423
- }
424
- Some(index + 1)
425
- }
426
-
427
- fn skip_ascii_whitespace(bytes: &[u8], index: &mut usize) {
428
- while matches!(bytes.get(*index).copied(), Some(byte) if byte.is_ascii_whitespace()) {
429
- *index += 1;
430
- }
431
- }
432
-
433
- fn consume_ascii_keyword(bytes: &[u8], index: &mut usize, keyword: &str) -> bool {
434
- let end = *index + keyword.len();
435
- let Some(slice) = bytes.get(*index..end) else {
436
- return false;
437
- };
438
- if !slice.eq_ignore_ascii_case(keyword.as_bytes()) {
439
- return false;
440
- }
441
- *index = end;
442
- true
443
- }
444
-
445
  impl Translatable for Document {
446
  fn get_source(&self) -> anyhow::Result<String> {
447
  Ok(format_document_blocks(&self.text_blocks))
@@ -466,12 +305,7 @@ impl Translatable for TextBlock {
466
  .text
467
  .clone()
468
  .ok_or_else(|| anyhow::anyhow!("No source text found"))?;
469
- Ok(format!(
470
- r#"<block id="0">
471
- {}
472
- </block>"#,
473
- escape_block_text(&source)
474
- ))
475
  }
476
 
477
  fn set_translation(&mut self, translation: String) -> anyhow::Result<()> {
@@ -929,10 +763,7 @@ mod tests {
929
  };
930
 
931
  let source = doc.get_source()?;
932
- assert_eq!(
933
- source,
934
- "<block id=\"0\">\nHello\n</block>\n<block id=\"1\">\n1 &lt; 2\nA &amp; B\n</block>"
935
- );
936
 
937
  Ok(())
938
  }
@@ -944,9 +775,7 @@ mod tests {
944
  ..Default::default()
945
  };
946
 
947
- doc.set_translation(
948
- "<block id=\"1\">\nSecond line\nnext\n</block>\n<block id=\"0\">\nFirst &lt;done&gt;\n</block>".to_string(),
949
- )?;
950
 
951
  assert_eq!(
952
  doc.text_blocks[0].translation.as_deref(),
@@ -967,10 +796,7 @@ mod tests {
967
  ..Default::default()
968
  };
969
 
970
- doc.set_translation(
971
- "<block id=\"0\">\n\"Hello\"\n</block>\n<block id=\"1\">\n“World”\n</block>"
972
- .to_string(),
973
- )?;
974
 
975
  assert_eq!(doc.text_blocks[0].translation.as_deref(), Some("Hello"));
976
  assert_eq!(doc.text_blocks[1].translation.as_deref(), Some("World"));
@@ -997,15 +823,13 @@ mod tests {
997
  }
998
 
999
  #[test]
1000
- fn document_translation_allows_missing_closing_tags() -> anyhow::Result<()> {
1001
  let mut doc = Document {
1002
  text_blocks: vec![TextBlock::default(), TextBlock::default()],
1003
  ..Default::default()
1004
  };
1005
 
1006
- doc.set_translation(
1007
- "<block id=\"0\">\nFirst line\n<block id=\"1\">\nSecond line".to_string(),
1008
- )?;
1009
 
1010
  assert_eq!(
1011
  doc.text_blocks[0].translation.as_deref(),
@@ -1020,14 +844,13 @@ mod tests {
1020
  }
1021
 
1022
  #[test]
1023
- fn document_translation_uses_end_of_text_when_last_closing_tag_is_missing() -> anyhow::Result<()>
1024
- {
1025
  let mut doc = Document {
1026
  text_blocks: vec![TextBlock::default()],
1027
  ..Default::default()
1028
  };
1029
 
1030
- doc.set_translation("<block id=\"0\">\nFinal line".to_string())?;
1031
 
1032
  assert_eq!(
1033
  doc.text_blocks[0].translation.as_deref(),
@@ -1044,49 +867,13 @@ mod tests {
1044
  ..Default::default()
1045
  };
1046
 
1047
- doc.set_translation(
1048
- "<block id=\"0\">\nKept\n</block>\n<block id=\"1\">\nIgnored\n</block>".to_string(),
1049
- )?;
1050
 
1051
  assert_eq!(doc.text_blocks[0].translation.as_deref(), Some("Kept"));
1052
 
1053
  Ok(())
1054
  }
1055
 
1056
- #[test]
1057
- fn document_translation_accepts_relaxed_block_tag_formatting() -> anyhow::Result<()> {
1058
- let mut doc = Document {
1059
- text_blocks: vec![TextBlock::default(), TextBlock::default()],
1060
- ..Default::default()
1061
- };
1062
-
1063
- doc.set_translation(
1064
- "<block id = '1' >\nSecond\n</ block>\n<Block id=0>\nFirst\n</BLOCK>".to_string(),
1065
- )?;
1066
-
1067
- assert_eq!(doc.text_blocks[0].translation.as_deref(), Some("First"));
1068
- assert_eq!(doc.text_blocks[1].translation.as_deref(), Some("Second"));
1069
-
1070
- Ok(())
1071
- }
1072
-
1073
- #[test]
1074
- fn document_translation_accepts_unquoted_block_ids() -> anyhow::Result<()> {
1075
- let mut doc = Document {
1076
- text_blocks: vec![TextBlock::default()],
1077
- ..Default::default()
1078
- };
1079
-
1080
- doc.set_translation("<block id=0>\nOnly first\n</block>".to_string())?;
1081
-
1082
- assert_eq!(
1083
- doc.text_blocks[0].translation.as_deref(),
1084
- Some("Only first")
1085
- );
1086
-
1087
- Ok(())
1088
- }
1089
-
1090
  #[test]
1091
  fn document_translation_pads_missing_tagged_blocks() -> anyhow::Result<()> {
1092
  let mut doc = Document {
@@ -1094,7 +881,7 @@ mod tests {
1094
  ..Default::default()
1095
  };
1096
 
1097
- doc.set_translation("<block id=\"0\">\nOnly first\n</block>".to_string())?;
1098
 
1099
  assert_eq!(
1100
  doc.text_blocks[0].translation.as_deref(),
@@ -1108,7 +895,7 @@ mod tests {
1108
  #[test]
1109
  fn text_block_translation_strips_wrapping_quotes() -> anyhow::Result<()> {
1110
  let mut block = TextBlock::default();
1111
- block.set_translation("quoted".to_string())?;
1112
  assert_eq!(block.translation.as_deref(), Some("quoted"));
1113
  Ok(())
1114
  }
@@ -1121,7 +908,7 @@ mod tests {
1121
  };
1122
 
1123
  let source = block.get_source()?;
1124
- assert_eq!(source, "<block id=\"0\">\n1 &lt; 2\nA &amp; B\n</block>");
1125
 
1126
  Ok(())
1127
  }
@@ -1129,10 +916,8 @@ mod tests {
1129
  #[test]
1130
  fn text_block_translation_extracts_tagged_block_content() -> anyhow::Result<()> {
1131
  let mut block = TextBlock::default();
1132
- block.set_translation(
1133
- "Sure.\n<block id=\"0\">\nTranslated &lt;line&gt;\n</block>\nDone.".to_string(),
1134
- )?;
1135
- assert_eq!(block.translation.as_deref(), Some("Translated <line>"));
1136
  Ok(())
1137
  }
1138
 
@@ -1150,8 +935,23 @@ mod tests {
1150
  #[test]
1151
  fn text_block_translation_keeps_japanese_dialogue_quotes() -> anyhow::Result<()> {
1152
  let mut block = TextBlock::default();
1153
- block.set_translation("quoted".to_string())?;
1154
- assert_eq!(block.translation.as_deref(), Some("quoted"));
1155
  Ok(())
1156
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1157
  }
 
26
 
27
  pub use koharu_llm::prefetch;
28
 
29
+ /// Matches a `<|N|>` tag, returning (full match length, 0-based index).
30
+ fn parse_block_tag(text: &str) -> Option<(usize, usize)> {
31
+ let bytes = text.as_bytes();
32
+ if bytes.get(0..2)? != b"<|" {
33
+ return None;
34
+ }
35
+ let end = text[2..].find("|>")?;
36
+ let num_str = &text[2..2 + end];
37
+ let id_1based: usize = num_str.parse().ok()?;
38
+ if id_1based == 0 {
39
+ return None;
40
+ }
41
+ Some((2 + end + 2, id_1based - 1))
42
  }
43
 
44
  #[allow(clippy::large_enum_variant)]
 
127
  }
128
  }
129
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  fn strip_wrapping_quotes(text: &str) -> String {
131
  let mut current = text.trim();
132
 
 
184
  .enumerate()
185
  .map(|(idx, block)| {
186
  let text = block.text.as_deref().unwrap_or("<empty>");
187
+ format!("<|{}|>{}", idx + 1, text)
 
 
 
 
 
188
  })
189
  .collect::<Vec<_>>()
190
  .join("\n")
191
  }
192
 
193
+ fn find_next_tag(text: &str) -> Option<(usize, usize, usize)> {
194
+ let mut pos = 0;
195
+ while let Some(rel) = text[pos..].find("<|") {
196
+ let offset = pos + rel;
197
+ if let Some((len, id)) = parse_block_tag(&text[offset..]) {
198
+ return Some((offset, len, id));
199
+ }
200
+ pos = offset + 1;
201
+ }
202
+ None
203
+ }
204
+
205
  fn parse_tagged_blocks(
206
  translation: &str,
207
  expected_blocks: usize,
208
  ) -> anyhow::Result<Option<Vec<String>>> {
209
+ if find_next_tag(translation).is_none() {
210
  return Ok(None);
211
  }
212
 
 
216
  let mut parsed_count = 0usize;
217
  let mut ignored_count = 0usize;
218
 
219
+ while let Some((offset, len, id)) = find_next_tag(cursor) {
220
  found_any = true;
221
+ cursor = &cursor[offset + len..];
222
+
223
+ // Find content: everything until the next tag or end of string
224
+ let content_end = find_next_tag(cursor)
225
+ .map(|(next_offset, _, _)| next_offset)
226
+ .unwrap_or(cursor.len());
227
+ let content = cursor[..content_end].trim().to_string();
228
 
 
229
  if id >= expected_blocks {
230
  ignored_count += 1;
231
  tracing::warn!("Ignoring translated block id {id} for {expected_blocks} source blocks");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  } else {
233
+ if blocks[id].is_empty() {
234
+ parsed_count += 1;
235
+ } else {
236
+ tracing::warn!("Translated block id {id} appeared more than once, keeping latest");
237
+ }
238
+ blocks[id] = content;
239
  }
 
240
 
241
+ cursor = &cursor[content_end..];
 
 
 
 
 
242
  }
243
 
244
  if !found_any {
 
281
  Ok(translations)
282
  }
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  impl Translatable for Document {
285
  fn get_source(&self) -> anyhow::Result<String> {
286
  Ok(format_document_blocks(&self.text_blocks))
 
305
  .text
306
  .clone()
307
  .ok_or_else(|| anyhow::anyhow!("No source text found"))?;
308
+ Ok(format!("<|1|>{}", source))
 
 
 
 
 
309
  }
310
 
311
  fn set_translation(&mut self, translation: String) -> anyhow::Result<()> {
 
763
  };
764
 
765
  let source = doc.get_source()?;
766
+ assert_eq!(source, "<|1|>Hello\n<|2|>1 < 2\nA & B");
 
 
 
767
 
768
  Ok(())
769
  }
 
775
  ..Default::default()
776
  };
777
 
778
+ doc.set_translation("<|2|>Second line\nnext\n<|1|>First <done>".to_string())?;
 
 
779
 
780
  assert_eq!(
781
  doc.text_blocks[0].translation.as_deref(),
 
796
  ..Default::default()
797
  };
798
 
799
+ doc.set_translation("<|1|>\"Hello\"\n<|2|>\u{201c}World\u{201d}".to_string())?;
 
 
 
800
 
801
  assert_eq!(doc.text_blocks[0].translation.as_deref(), Some("Hello"));
802
  assert_eq!(doc.text_blocks[1].translation.as_deref(), Some("World"));
 
823
  }
824
 
825
  #[test]
826
+ fn document_translation_parses_consecutive_tags() -> anyhow::Result<()> {
827
  let mut doc = Document {
828
  text_blocks: vec![TextBlock::default(), TextBlock::default()],
829
  ..Default::default()
830
  };
831
 
832
+ doc.set_translation("<|1|>First line\n<|2|>Second line".to_string())?;
 
 
833
 
834
  assert_eq!(
835
  doc.text_blocks[0].translation.as_deref(),
 
844
  }
845
 
846
  #[test]
847
+ fn document_translation_uses_end_of_text_for_last_block() -> anyhow::Result<()> {
 
848
  let mut doc = Document {
849
  text_blocks: vec![TextBlock::default()],
850
  ..Default::default()
851
  };
852
 
853
+ doc.set_translation("<|1|>Final line".to_string())?;
854
 
855
  assert_eq!(
856
  doc.text_blocks[0].translation.as_deref(),
 
867
  ..Default::default()
868
  };
869
 
870
+ doc.set_translation("<|1|>Kept\n<|2|>Ignored".to_string())?;
 
 
871
 
872
  assert_eq!(doc.text_blocks[0].translation.as_deref(), Some("Kept"));
873
 
874
  Ok(())
875
  }
876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877
  #[test]
878
  fn document_translation_pads_missing_tagged_blocks() -> anyhow::Result<()> {
879
  let mut doc = Document {
 
881
  ..Default::default()
882
  };
883
 
884
+ doc.set_translation("<|1|>Only first".to_string())?;
885
 
886
  assert_eq!(
887
  doc.text_blocks[0].translation.as_deref(),
 
895
  #[test]
896
  fn text_block_translation_strips_wrapping_quotes() -> anyhow::Result<()> {
897
  let mut block = TextBlock::default();
898
+ block.set_translation("\u{201c}quoted\u{201d}".to_string())?;
899
  assert_eq!(block.translation.as_deref(), Some("quoted"));
900
  Ok(())
901
  }
 
908
  };
909
 
910
  let source = block.get_source()?;
911
+ assert_eq!(source, "<|1|>1 < 2\nA & B");
912
 
913
  Ok(())
914
  }
 
916
  #[test]
917
  fn text_block_translation_extracts_tagged_block_content() -> anyhow::Result<()> {
918
  let mut block = TextBlock::default();
919
+ block.set_translation("Sure.\n<|1|>Translated text".to_string())?;
920
+ assert_eq!(block.translation.as_deref(), Some("Translated text"));
 
 
921
  Ok(())
922
  }
923
 
 
935
  #[test]
936
  fn text_block_translation_keeps_japanese_dialogue_quotes() -> anyhow::Result<()> {
937
  let mut block = TextBlock::default();
938
+ block.set_translation("\u{300c}quoted\u{300d}".to_string())?;
939
+ assert_eq!(block.translation.as_deref(), Some("\u{300c}quoted\u{300d}"));
940
  Ok(())
941
  }
942
+
943
+ #[test]
944
+ fn parse_block_tag_parses_valid_tags() {
945
+ assert_eq!(parse_block_tag("<|1|>"), Some((5, 0)));
946
+ assert_eq!(parse_block_tag("<|2|>"), Some((5, 1)));
947
+ assert_eq!(parse_block_tag("<|10|>"), Some((6, 9)));
948
+ }
949
+
950
+ #[test]
951
+ fn parse_block_tag_rejects_invalid_tags() {
952
+ assert_eq!(parse_block_tag("<|0|>"), None);
953
+ assert_eq!(parse_block_tag("<|abc|>"), None);
954
+ assert_eq!(parse_block_tag("<block>"), None);
955
+ assert_eq!(parse_block_tag("hello"), None);
956
+ }
957
  }
koharu-app/src/renderer.rs CHANGED
@@ -153,17 +153,25 @@ impl Renderer {
153
  .fontbook
154
  .lock()
155
  .map_err(|_| anyhow::anyhow!("Failed to lock fontbook"))?;
 
156
  let mut fonts = fontbook
157
  .all_families()
158
  .into_iter()
159
  .filter(|face| !face.post_script_name.is_empty())
160
- .map(|face| FontFaceInfo {
161
- family_name: face
162
  .families
163
  .first()
164
  .map(|(family, _)| family.clone())
165
- .unwrap_or_else(|| face.post_script_name.clone()),
166
- post_script_name: face.post_script_name,
 
 
 
 
 
 
 
167
  })
168
  .collect::<Vec<_>>();
169
  fonts.sort();
@@ -201,6 +209,10 @@ impl Renderer {
201
  None => clustered_sizes,
202
  };
203
 
 
 
 
 
204
  let block_renders: Vec<Option<(DynamicImage, usize)>> = {
205
  use rayon::prelude::*;
206
  let span = tracing::info_span!("render_blocks", count = blocks_to_render.len());
@@ -212,12 +224,13 @@ impl Renderer {
212
  .map(|(i, text_block)| {
213
  let _guard = parent_span.enter();
214
  let base_font_size = render_font_sizes.get(i).copied().flatten();
 
215
  match self.render_text_block(
216
  text_block,
217
  opts.shader_effect,
218
  opts.shader_stroke.clone(),
219
  opts.font_family,
220
- opts.bubbles,
221
  base_font_size,
222
  min_font,
223
  ) {
@@ -256,12 +269,9 @@ impl Renderer {
256
  continue;
257
  };
258
  let block_img = images.load(blob_ref)?;
259
- imageops::overlay(
260
- &mut rendered,
261
- &block_img,
262
- text_block.x as i64,
263
- text_block.y as i64,
264
- );
265
  }
266
  let rendered_ref = images.store_raw(&DynamicImage::from(rendered))?;
267
  return Ok(Some(rendered_ref));
@@ -276,7 +286,7 @@ impl Renderer {
276
  effect: TextShaderEffect,
277
  global_stroke: Option<TextStrokeStyle>,
278
  font_family: Option<&str>,
279
- bubbles: &[koharu_core::BubbleRegion],
280
  base_font_size: Option<f32>,
281
  min_font_size: f32,
282
  ) -> Result<Option<DynamicImage>> {
@@ -340,23 +350,16 @@ impl Renderer {
340
  }
341
  });
342
  let block_box = layout_box_from_block(&layout_source_block);
343
- let bubble_box = find_best_bubble(text_block, bubbles);
344
- // Use whichever is larger — bubble or text block box.
345
- let layout_box = match bubble_box {
346
- Some(bb) => LayoutBox {
347
- x: block_box.x.min(bb.x),
348
- y: block_box.y.min(bb.y),
349
- width: block_box.width.max(bb.width),
350
- height: block_box.height.max(bb.height),
351
- },
352
- None => block_box,
353
  };
354
 
355
- // Determine base font size: user-set > clustered detected.
356
- // If neither is available, skip rendering (no font size info).
357
- let Some(effective_base_size) = style.font_size.or(base_font_size) else {
358
- return Ok(None);
359
- };
360
 
361
  let layout_builder = TextLayout::new(&font, None)
362
  .with_fallback_fonts(&self.symbol_fallbacks)
@@ -383,6 +386,7 @@ impl Renderer {
383
  style.stroke.as_ref(),
384
  global_stroke.as_ref(),
385
  layout.font_size,
 
386
  );
387
  let rendered = {
388
  let _s = tracing::info_span!("rasterize").entered();
@@ -399,14 +403,22 @@ impl Renderer {
399
  )?
400
  };
401
 
402
- text_block.x = layout_box.x;
403
- text_block.y = layout_box.y;
404
- text_block.width = layout_box.width;
405
- text_block.height = layout_box.height;
406
  text_block.rendered_direction = Some(match writing_mode {
407
  WritingMode::Horizontal => koharu_core::TextDirection::Horizontal,
408
  WritingMode::VerticalRl => koharu_core::TextDirection::Vertical,
409
  });
 
 
 
 
 
 
 
 
 
 
 
 
410
  // rendered field will be set by the caller with a BlobRef
411
  let persisted_style = text_block.style.get_or_insert_with(|| TextStyle {
412
  font_families: Vec::new(),
@@ -455,18 +467,41 @@ fn apply_default_font_families(font_families: &mut Vec<String>, text: &str) {
455
  }
456
  }
457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  fn resolve_stroke_style(
459
  block: &TextBlock,
460
  block_stroke: Option<&TextStrokeStyle>,
461
  global_stroke: Option<&TextStrokeStyle>,
462
  font_size: f32,
 
463
  ) -> Option<RenderStrokeOptions> {
464
  if let Some(stroke) = block_stroke {
465
  if !stroke.enabled {
466
  return None;
467
  }
 
 
 
 
 
468
  return Some(RenderStrokeOptions {
469
- color: stroke.color,
470
  width_px: stroke
471
  .width_px
472
  .unwrap_or_else(|| default_stroke_width(font_size)),
@@ -477,8 +512,13 @@ fn resolve_stroke_style(
477
  if !stroke.enabled {
478
  return None;
479
  }
 
 
 
 
 
480
  return Some(RenderStrokeOptions {
481
- color: stroke.color,
482
  width_px: stroke
483
  .width_px
484
  .unwrap_or_else(|| default_stroke_width(font_size)),
@@ -488,19 +528,24 @@ fn resolve_stroke_style(
488
  if let Some(pred) = &block.font_prediction
489
  && pred.stroke_width_px > 0.0
490
  {
 
 
 
 
 
 
491
  return Some(RenderStrokeOptions {
492
- color: [
493
- pred.stroke_color[0],
494
- pred.stroke_color[1],
495
- pred.stroke_color[2],
496
- 255,
497
- ],
498
  width_px: pred.stroke_width_px,
499
  });
500
  }
501
 
502
  Some(RenderStrokeOptions {
503
- color: [255, 255, 255, 255],
504
  width_px: default_stroke_width(font_size),
505
  })
506
  }
@@ -598,47 +643,88 @@ fn face_post_script_name(faces: &[FaceInfo], candidate: &str) -> Option<String>
598
  .filter(|post_script_name| !post_script_name.is_empty())
599
  }
600
 
601
- /// Find the bubble with the best IoU overlap for a text block.
602
- /// Returns the bubble's bbox as a LayoutBox if a good match is found.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
603
  fn find_best_bubble(block: &TextBlock, bubbles: &[koharu_core::BubbleRegion]) -> Option<LayoutBox> {
604
  if bubbles.is_empty() {
605
  return None;
606
  }
607
-
 
 
608
  let block_area = (block.width * block.height).max(1.0);
609
  let mut best: Option<(f32, &koharu_core::BubbleRegion)> = None;
610
-
611
  for bubble in bubbles {
 
 
 
 
 
 
 
 
612
  let bubble_area = (bubble.width * bubble.height).max(1.0);
613
-
614
- // Intersection
615
  let ix0 = block.x.max(bubble.x);
616
  let iy0 = block.y.max(bubble.y);
617
  let ix1 = (block.x + block.width).min(bubble.x + bubble.width);
618
  let iy1 = (block.y + block.height).min(bubble.y + bubble.height);
619
  let inter = (ix1 - ix0).max(0.0) * (iy1 - iy0).max(0.0);
620
-
621
- // IoU
622
  let union = block_area + bubble_area - inter;
623
  let iou = if union > 0.0 { inter / union } else { 0.0 };
624
-
625
- if let Some((best_iou, _)) = &best {
626
- if iou > *best_iou {
627
- best = Some((iou, bubble));
628
- }
629
- } else if iou > 0.0 {
630
- best = Some((iou, bubble));
631
  }
632
  }
633
-
634
- // Require minimum overlap to avoid false matches
635
- best.filter(|(iou, _)| *iou > 0.05)
636
- .map(|(_, bubble)| LayoutBox {
637
- x: bubble.x,
638
- y: bubble.y,
639
- width: bubble.width,
640
- height: bubble.height,
641
- })
642
  }
643
 
644
  #[cfg(test)]
 
153
  .fontbook
154
  .lock()
155
  .map_err(|_| anyhow::anyhow!("Failed to lock fontbook"))?;
156
+ let mut seen = std::collections::HashSet::new();
157
  let mut fonts = fontbook
158
  .all_families()
159
  .into_iter()
160
  .filter(|face| !face.post_script_name.is_empty())
161
+ .filter_map(|face| {
162
+ let family_name = face
163
  .families
164
  .first()
165
  .map(|(family, _)| family.clone())
166
+ .unwrap_or_else(|| face.post_script_name.clone());
167
+ if seen.insert(family_name.clone()) {
168
+ Some(FontFaceInfo {
169
+ family_name,
170
+ post_script_name: face.post_script_name,
171
+ })
172
+ } else {
173
+ None
174
+ }
175
  })
176
  .collect::<Vec<_>>();
177
  fonts.sort();
 
209
  None => clustered_sizes,
210
  };
211
 
212
+ // Bubble expansion disabled — causes text to fly away from block position.
213
+ // Bubbles are used for font size estimation via detection, not layout.
214
+ let render_areas: Vec<Option<LayoutBox>> = vec![None; blocks_to_render.len()];
215
+
216
  let block_renders: Vec<Option<(DynamicImage, usize)>> = {
217
  use rayon::prelude::*;
218
  let span = tracing::info_span!("render_blocks", count = blocks_to_render.len());
 
224
  .map(|(i, text_block)| {
225
  let _guard = parent_span.enter();
226
  let base_font_size = render_font_sizes.get(i).copied().flatten();
227
+ let bubble_area = render_areas.get(i).copied().flatten();
228
  match self.render_text_block(
229
  text_block,
230
  opts.shader_effect,
231
  opts.shader_stroke.clone(),
232
  opts.font_family,
233
+ bubble_area,
234
  base_font_size,
235
  min_font,
236
  ) {
 
269
  continue;
270
  };
271
  let block_img = images.load(blob_ref)?;
272
+ let rx = text_block.render_x.unwrap_or(text_block.x);
273
+ let ry = text_block.render_y.unwrap_or(text_block.y);
274
+ imageops::overlay(&mut rendered, &block_img, rx as i64, ry as i64);
 
 
 
275
  }
276
  let rendered_ref = images.store_raw(&DynamicImage::from(rendered))?;
277
  return Ok(Some(rendered_ref));
 
286
  effect: TextShaderEffect,
287
  global_stroke: Option<TextStrokeStyle>,
288
  font_family: Option<&str>,
289
+ bubble_area: Option<LayoutBox>,
290
  base_font_size: Option<f32>,
291
  min_font_size: f32,
292
  ) -> Result<Option<DynamicImage>> {
 
350
  }
351
  });
352
  let block_box = layout_box_from_block(&layout_source_block);
353
+ let (layout_box, bubble_expanded) = match bubble_area {
354
+ Some(area) => (area, true),
355
+ None => (block_box, false),
 
 
 
 
 
 
 
356
  };
357
 
358
+ // Determine base font size: user-set > clustered detected > fallback.
359
+ let effective_base_size = style.font_size.or(base_font_size).unwrap_or_else(|| {
360
+ // Fallback: estimate from layout box height.
361
+ (layout_box.height * 0.3).clamp(min_font_size, 60.0)
362
+ });
363
 
364
  let layout_builder = TextLayout::new(&font, None)
365
  .with_fallback_fonts(&self.symbol_fallbacks)
 
386
  style.stroke.as_ref(),
387
  global_stroke.as_ref(),
388
  layout.font_size,
389
+ color,
390
  );
391
  let rendered = {
392
  let _s = tracing::info_span!("rasterize").entered();
 
403
  )?
404
  };
405
 
 
 
 
 
406
  text_block.rendered_direction = Some(match writing_mode {
407
  WritingMode::Horizontal => koharu_core::TextDirection::Horizontal,
408
  WritingMode::VerticalRl => koharu_core::TextDirection::Vertical,
409
  });
410
+ // Store actual render area when bubble expansion was used.
411
+ if bubble_expanded {
412
+ text_block.render_x = Some(layout_box.x.round());
413
+ text_block.render_y = Some(layout_box.y.round());
414
+ text_block.render_width = Some(layout_box.width.round());
415
+ text_block.render_height = Some(layout_box.height.round());
416
+ } else {
417
+ text_block.render_x = None;
418
+ text_block.render_y = None;
419
+ text_block.render_width = None;
420
+ text_block.render_height = None;
421
+ }
422
  // rendered field will be set by the caller with a BlobRef
423
  let persisted_style = text_block.style.get_or_insert_with(|| TextStyle {
424
  font_families: Vec::new(),
 
467
  }
468
  }
469
 
470
+ fn colors_too_similar(a: [u8; 4], b: [u8; 4]) -> bool {
471
+ let dr = (a[0] as i32 - b[0] as i32).abs();
472
+ let dg = (a[1] as i32 - b[1] as i32).abs();
473
+ let db = (a[2] as i32 - b[2] as i32).abs();
474
+ dr + dg + db < 60
475
+ }
476
+
477
+ fn contrasting_stroke_color(text_color: [u8; 4]) -> [u8; 4] {
478
+ let luminance =
479
+ 0.299 * text_color[0] as f32 + 0.587 * text_color[1] as f32 + 0.114 * text_color[2] as f32;
480
+ if luminance > 128.0 {
481
+ [0, 0, 0, 255]
482
+ } else {
483
+ [255, 255, 255, 255]
484
+ }
485
+ }
486
+
487
  fn resolve_stroke_style(
488
  block: &TextBlock,
489
  block_stroke: Option<&TextStrokeStyle>,
490
  global_stroke: Option<&TextStrokeStyle>,
491
  font_size: f32,
492
+ text_color: [u8; 4],
493
  ) -> Option<RenderStrokeOptions> {
494
  if let Some(stroke) = block_stroke {
495
  if !stroke.enabled {
496
  return None;
497
  }
498
+ let color = if colors_too_similar(text_color, stroke.color) {
499
+ contrasting_stroke_color(text_color)
500
+ } else {
501
+ stroke.color
502
+ };
503
  return Some(RenderStrokeOptions {
504
+ color,
505
  width_px: stroke
506
  .width_px
507
  .unwrap_or_else(|| default_stroke_width(font_size)),
 
512
  if !stroke.enabled {
513
  return None;
514
  }
515
+ let color = if colors_too_similar(text_color, stroke.color) {
516
+ contrasting_stroke_color(text_color)
517
+ } else {
518
+ stroke.color
519
+ };
520
  return Some(RenderStrokeOptions {
521
+ color,
522
  width_px: stroke
523
  .width_px
524
  .unwrap_or_else(|| default_stroke_width(font_size)),
 
528
  if let Some(pred) = &block.font_prediction
529
  && pred.stroke_width_px > 0.0
530
  {
531
+ let pred_stroke = [
532
+ pred.stroke_color[0],
533
+ pred.stroke_color[1],
534
+ pred.stroke_color[2],
535
+ 255,
536
+ ];
537
  return Some(RenderStrokeOptions {
538
+ color: if colors_too_similar(text_color, pred_stroke) {
539
+ contrasting_stroke_color(text_color)
540
+ } else {
541
+ pred_stroke
542
+ },
 
543
  width_px: pred.stroke_width_px,
544
  });
545
  }
546
 
547
  Some(RenderStrokeOptions {
548
+ color: contrasting_stroke_color(text_color),
549
  width_px: default_stroke_width(font_size),
550
  })
551
  }
 
643
  .filter(|post_script_name| !post_script_name.is_empty())
644
  }
645
 
646
+ #[allow(dead_code)]
647
+ fn compute_render_areas(
648
+ blocks: &[&mut TextBlock],
649
+ bubbles: &[koharu_core::BubbleRegion],
650
+ ) -> Vec<Option<LayoutBox>> {
651
+ // First pass: compute bubble-expanded area for each block.
652
+ let mut areas: Vec<Option<LayoutBox>> = blocks
653
+ .iter()
654
+ .map(|block| {
655
+ if block.lock_layout_box {
656
+ return None;
657
+ }
658
+ let bubble = find_best_bubble(block, bubbles)?;
659
+ let pad_x = bubble.width * 0.05;
660
+ let pad_y = bubble.height * 0.05;
661
+ Some(LayoutBox {
662
+ x: (bubble.x + pad_x).round(),
663
+ y: (bubble.y + pad_y).round(),
664
+ width: (bubble.width - pad_x * 2.0).round().max(1.0),
665
+ height: (bubble.height - pad_y * 2.0).round().max(1.0),
666
+ })
667
+ })
668
+ .collect();
669
+
670
+ // Second pass: if any two expanded areas overlap, fall back both to block dims.
671
+ for i in 0..areas.len() {
672
+ for j in (i + 1)..areas.len() {
673
+ let (Some(a), Some(b)) = (areas[i], areas[j]) else {
674
+ continue;
675
+ };
676
+ let overlap_x = (a.x + a.width).min(b.x + b.width) - a.x.max(b.x);
677
+ let overlap_y = (a.y + a.height).min(b.y + b.height) - a.y.max(b.y);
678
+ if overlap_x > 0.0 && overlap_y > 0.0 {
679
+ areas[i] = None;
680
+ areas[j] = None;
681
+ break;
682
+ }
683
+ }
684
+ }
685
+
686
+ areas
687
+ }
688
+
689
+ #[allow(dead_code)]
690
  fn find_best_bubble(block: &TextBlock, bubbles: &[koharu_core::BubbleRegion]) -> Option<LayoutBox> {
691
  if bubbles.is_empty() {
692
  return None;
693
  }
694
+ // Block center must be inside the bubble.
695
+ let cx = block.x + block.width * 0.5;
696
+ let cy = block.y + block.height * 0.5;
697
  let block_area = (block.width * block.height).max(1.0);
698
  let mut best: Option<(f32, &koharu_core::BubbleRegion)> = None;
 
699
  for bubble in bubbles {
700
+ // Check containment: block center inside bubble.
701
+ if cx < bubble.x
702
+ || cx > bubble.x + bubble.width
703
+ || cy < bubble.y
704
+ || cy > bubble.y + bubble.height
705
+ {
706
+ continue;
707
+ }
708
  let bubble_area = (bubble.width * bubble.height).max(1.0);
 
 
709
  let ix0 = block.x.max(bubble.x);
710
  let iy0 = block.y.max(bubble.y);
711
  let ix1 = (block.x + block.width).min(bubble.x + bubble.width);
712
  let iy1 = (block.y + block.height).min(bubble.y + bubble.height);
713
  let inter = (ix1 - ix0).max(0.0) * (iy1 - iy0).max(0.0);
 
 
714
  let union = block_area + bubble_area - inter;
715
  let iou = if union > 0.0 { inter / union } else { 0.0 };
716
+ match &best {
717
+ Some((best_iou, _)) if iou > *best_iou => best = Some((iou, bubble)),
718
+ None if iou > 0.0 => best = Some((iou, bubble)),
719
+ _ => {}
 
 
 
720
  }
721
  }
722
+ best.filter(|(iou, _)| *iou > 0.1).map(|(_, b)| LayoutBox {
723
+ x: b.x,
724
+ y: b.y,
725
+ width: b.width,
726
+ height: b.height,
727
+ })
 
 
 
728
  }
729
 
730
  #[cfg(test)]
koharu-core/src/lib.rs CHANGED
@@ -71,6 +71,16 @@ pub struct TextBlock {
71
  pub rendered: Option<BlobRef>,
72
  #[serde(default)]
73
  pub lock_layout_box: bool,
 
 
 
 
 
 
 
 
 
 
74
  }
75
 
76
  impl Default for TextBlock {
@@ -95,6 +105,10 @@ impl Default for TextBlock {
95
  font_prediction: None,
96
  rendered: None,
97
  lock_layout_box: false,
 
 
 
 
98
  }
99
  }
100
  }
 
71
  pub rendered: Option<BlobRef>,
72
  #[serde(default)]
73
  pub lock_layout_box: bool,
74
+ /// Actual render area — set by renderer when bubble expansion is used.
75
+ /// Frontend and composite use these for sprite positioning when present.
76
+ #[serde(default, skip_serializing_if = "Option::is_none")]
77
+ pub render_x: Option<f32>,
78
+ #[serde(default, skip_serializing_if = "Option::is_none")]
79
+ pub render_y: Option<f32>,
80
+ #[serde(default, skip_serializing_if = "Option::is_none")]
81
+ pub render_width: Option<f32>,
82
+ #[serde(default, skip_serializing_if = "Option::is_none")]
83
+ pub render_height: Option<f32>,
84
  }
85
 
86
  impl Default for TextBlock {
 
105
  font_prediction: None,
106
  rendered: None,
107
  lock_layout_box: false,
108
+ render_x: None,
109
+ render_y: None,
110
+ render_width: None,
111
+ render_height: None,
112
  }
113
  }
114
  }
koharu-core/src/protocol.rs CHANGED
@@ -57,6 +57,15 @@ pub struct TextBlockDetail {
57
  /// Blob hash for the rendered text block sprite.
58
  #[serde(skip_serializing_if = "Option::is_none")]
59
  pub rendered: Option<String>,
 
 
 
 
 
 
 
 
 
60
  }
61
 
62
  impl From<&TextBlock> for TextBlockDetail {
@@ -80,6 +89,10 @@ impl From<&TextBlock> for TextBlockDetail {
80
  style: block.style.clone(),
81
  font_prediction: block.font_prediction.clone(),
82
  rendered: block.rendered.as_ref().map(|r| r.hash().to_string()),
 
 
 
 
83
  }
84
  }
85
  }
 
57
  /// Blob hash for the rendered text block sprite.
58
  #[serde(skip_serializing_if = "Option::is_none")]
59
  pub rendered: Option<String>,
60
+ /// Actual render area position/size (when bubble expansion is used).
61
+ #[serde(skip_serializing_if = "Option::is_none")]
62
+ pub render_x: Option<f32>,
63
+ #[serde(skip_serializing_if = "Option::is_none")]
64
+ pub render_y: Option<f32>,
65
+ #[serde(skip_serializing_if = "Option::is_none")]
66
+ pub render_width: Option<f32>,
67
+ #[serde(skip_serializing_if = "Option::is_none")]
68
+ pub render_height: Option<f32>,
69
  }
70
 
71
  impl From<&TextBlock> for TextBlockDetail {
 
89
  style: block.style.clone(),
90
  font_prediction: block.font_prediction.clone(),
91
  rendered: block.rendered.as_ref().map(|r| r.hash().to_string()),
92
+ render_x: block.render_x,
93
+ render_y: block.render_y,
94
+ render_width: block.render_width,
95
+ render_height: block.render_height,
96
  }
97
  }
98
  }
koharu-llm/src/prompt.rs CHANGED
@@ -46,7 +46,7 @@ pub struct PromptRenderer {
46
  eos_token: String,
47
  }
48
 
49
- const BLOCK_TAG_INSTRUCTIONS: &str = "If the input contains <block id=\"N\">...</block>, translate only the text inside each block. Keep every block tag exactly unchanged, including ids, order, and block count. Do not merge blocks, split blocks, or add any text outside the blocks.";
50
 
51
  pub fn system_prompt(target_language: Language) -> String {
52
  format!(
@@ -132,8 +132,8 @@ mod tests {
132
  fn system_prompt_mentions_target_language_and_block_rules() {
133
  let prompt = system_prompt(Language::Korean);
134
  assert!(prompt.contains("natural Korean"));
135
- assert!(prompt.contains("<block id=\"N\">...</block>"));
136
- assert!(prompt.contains("Do not merge blocks"));
137
  }
138
 
139
  #[test]
 
46
  eos_token: String,
47
  }
48
 
49
+ const BLOCK_TAG_INSTRUCTIONS: &str = "The input uses numbered tags like <|1|>, <|2|>, etc. to mark each text block. Translate only the text after each tag. Keep every tag exactly unchanged, including numbers and order. Output the same tags followed by the translated text. Do not merge, split, or reorder blocks.";
50
 
51
  pub fn system_prompt(target_language: Language) -> String {
52
  format!(
 
132
  fn system_prompt_mentions_target_language_and_block_rules() {
133
  let prompt = system_prompt(Language::Korean);
134
  assert!(prompt.contains("natural Korean"));
135
+ assert!(prompt.contains("<|1|>, <|2|>"));
136
+ assert!(prompt.contains("Do not merge"));
137
  }
138
 
139
  #[test]
koharu-ml/src/comic_text_detector/mod.rs CHANGED
@@ -3,11 +3,13 @@ mod postprocess;
3
  mod unet;
4
  mod yolo_v5;
5
 
6
- use std::{cmp, time::Instant};
7
 
8
- use anyhow::Context;
9
  use candle_core::{DType, Device, IndexOp, Tensor};
10
- use image::{DynamicImage, GenericImageView, GrayImage, RgbImage, imageops};
 
 
11
  use koharu_runtime::RuntimeManager;
12
  use tracing::instrument;
13
 
@@ -18,46 +20,34 @@ pub use postprocess::{
18
  refine_segmentation_mask,
19
  };
20
 
21
- const GPU_DETECT_SIZE: u32 = 1280;
22
- const CPU_DETECT_SIZE: u32 = 640;
23
- const DET_REARRANGE_MAX_BATCHES: usize = 4;
24
- const DET_REARRANGE_DOWNSCALE_THRESHOLD: f32 = 2.5;
25
- const DET_REARRANGE_ASPECT_THRESHOLD: f32 = 3.0;
26
-
27
- struct StitchBuffers<'a> {
28
- shrink_sum: &'a mut [f32],
29
- threshold_sum: &'a mut [f32],
30
- mask_sum: &'a mut [f32],
31
- counts: &'a mut [f32],
32
- }
33
-
34
- struct PatchPlacement {
35
- width: u32,
36
- height: u32,
37
- offset_x: u32,
38
- top: u32,
39
- actual_height: u32,
40
- }
41
-
42
  const HF_REPO: &str = "mayocream/comic-text-detector";
 
 
 
 
 
 
 
 
 
43
 
44
  koharu_runtime::declare_hf_model_package!(
45
  id: "model:comic-text-detector:yolo-v5",
46
- repo: "mayocream/comic-text-detector",
47
  file: "yolo-v5.safetensors",
48
  bootstrap: true,
49
  order: 110,
50
  );
51
  koharu_runtime::declare_hf_model_package!(
52
  id: "model:comic-text-detector:unet",
53
- repo: "mayocream/comic-text-detector",
54
  file: "unet.safetensors",
55
  bootstrap: true,
56
  order: 111,
57
  );
58
  koharu_runtime::declare_hf_model_package!(
59
  id: "model:comic-text-detector:dbnet",
60
- repo: "mayocream/comic-text-detector",
61
  file: "dbnet.safetensors",
62
  bootstrap: true,
63
  order: 112,
@@ -122,365 +112,90 @@ impl ComicTextDetector {
122
 
123
  #[instrument(level = "debug", skip_all)]
124
  pub fn inference(&self, image: &DynamicImage) -> anyhow::Result<ComicTextDetection> {
125
- let detect_size = self.detect_size();
126
- let maps = if let Some(maps) =
127
- self.try_rearranged_maps(image, detect_size, DET_REARRANGE_MAX_BATCHES)?
128
- {
129
- maps
130
- } else {
131
- let original_dimensions = image.dimensions();
132
- let (image_tensor, resized_dimensions) = preprocess(image, &self.device, detect_size)?;
133
- let (mask, shrink_threshold) = self.forward(&image_tensor)?;
134
- postprocess_maps(
135
- &mask,
136
- &shrink_threshold,
137
- original_dimensions,
138
- resized_dimensions,
139
- )?
140
- };
 
 
 
 
 
141
 
142
- postprocess::build_detection(image, maps)
 
 
 
 
 
 
143
  }
144
 
145
  #[instrument(level = "debug", skip_all)]
146
  pub fn inference_segmentation(&self, image: &DynamicImage) -> anyhow::Result<GrayImage> {
147
- let started = Instant::now();
148
- let detect_size = self.detect_size();
149
- let (mask_map, rearranged) = if let Some(mask_map) =
150
- self.try_rearranged_mask_map(image, detect_size, DET_REARRANGE_MAX_BATCHES)?
151
- {
152
- (mask_map, true)
153
- } else {
154
- let original_dimensions = image.dimensions();
155
- let (image_tensor, resized_dimensions) = preprocess(image, &self.device, detect_size)?;
156
- let mask = self.forward_mask(&image_tensor)?;
157
- (
158
- postprocess_mask(&mask, original_dimensions, resized_dimensions)?,
159
- false,
160
- )
161
- };
162
-
163
- tracing::info!(
164
- width = image.width(),
165
- height = image.height(),
166
- rearranged,
167
- total_ms = started.elapsed().as_millis(),
168
- "comic text detector segmentation timings"
169
- );
170
-
171
- Ok(mask_map)
172
  }
173
 
174
  #[instrument(level = "debug", skip_all)]
175
- fn forward(&self, image: &Tensor) -> anyhow::Result<(Tensor, Tensor)> {
176
- let (mask, features) = self.forward_yolo_unet(image)?;
 
 
 
 
 
 
 
177
  let dbnet = self
178
  .dbnet
179
  .as_ref()
180
  .context("DBNet not loaded; use ComicTextDetector::load for full detection")?;
181
  let shrink_thresh = dbnet.forward(&features[0], &features[1], &features[2])?;
182
 
183
- Ok((mask, shrink_thresh))
184
  }
185
 
186
  #[instrument(level = "debug", skip_all)]
187
  fn forward_mask(&self, image: &Tensor) -> anyhow::Result<Tensor> {
188
- let (mask, _features) = self.forward_yolo_unet(image)?;
189
- Ok(mask)
190
- }
191
-
192
- #[instrument(level = "debug", skip_all)]
193
- fn forward_yolo_unet(&self, image: &Tensor) -> anyhow::Result<(Tensor, [Tensor; 3])> {
194
  let (_, features) = self.yolo.forward(image)?;
195
- let (mask, features) = self.unet.forward(
196
  &features[0],
197
  &features[1],
198
  &features[2],
199
  &features[3],
200
  &features[4],
201
  )?;
202
- Ok((mask, features))
203
- }
204
-
205
- fn detect_size(&self) -> u32 {
206
- match self.device {
207
- Device::Cpu => CPU_DETECT_SIZE,
208
- _ => GPU_DETECT_SIZE,
209
- }
210
- }
211
-
212
- fn try_rearranged_maps(
213
- &self,
214
- image: &DynamicImage,
215
- detect_size: u32,
216
- max_batch_size: usize,
217
- ) -> anyhow::Result<Option<postprocess::DetectionMaps>> {
218
- let Some(layout) = build_rearranged_layout(image, detect_size) else {
219
- return Ok(None);
220
- };
221
- let RearrangedLayout {
222
- transpose,
223
- width,
224
- height,
225
- pw_num,
226
- metadata,
227
- composites,
228
- composite_size,
229
- } = layout;
230
-
231
- let pixel_count = (width * height) as usize;
232
- let mut shrink_sum = vec![0.0f32; pixel_count];
233
- let mut threshold_sum = vec![0.0f32; pixel_count];
234
- let mut mask_sum = vec![0.0f32; pixel_count];
235
- let mut counts = vec![0.0f32; pixel_count];
236
-
237
- for batch_start in (0..composites.len()).step_by(max_batch_size.max(1)) {
238
- let batch_end = (batch_start + max_batch_size).min(composites.len());
239
- let mut tensors = Vec::with_capacity(batch_end - batch_start);
240
- for composite in &composites[batch_start..batch_end] {
241
- tensors.push(preprocess_rgb_image(composite, &self.device, detect_size)?);
242
- }
243
- let refs: Vec<&Tensor> = tensors.iter().collect();
244
- let batch = Tensor::cat(&refs, 0)?;
245
- let (mask_batch, shrink_threshold_batch) = self.forward(&batch)?;
246
-
247
- for batch_index in 0..(batch_end - batch_start) {
248
- let mask_map = tensor_channel_to_score_map_resized(
249
- &mask_batch.i((batch_index, 0))?,
250
- composite_size,
251
- composite_size,
252
- )?;
253
- let shrink_map = tensor_channel_to_score_map_resized(
254
- &shrink_threshold_batch.i((batch_index, 0))?,
255
- composite_size,
256
- composite_size,
257
- )?;
258
- let threshold_map = tensor_channel_to_score_map_resized(
259
- &shrink_threshold_batch.i((batch_index, 1))?,
260
- composite_size,
261
- composite_size,
262
- )?;
263
-
264
- for slot in 0..pw_num as usize {
265
- let patch_index = (batch_start + batch_index) * pw_num as usize + slot;
266
- if patch_index >= metadata.len() {
267
- break;
268
- }
269
- let (top, actual_height) = metadata[patch_index];
270
- if actual_height == 0 {
271
- continue;
272
- }
273
-
274
- let offset_x = slot as u32 * width;
275
- stitch_patch(
276
- StitchBuffers {
277
- shrink_sum: &mut shrink_sum,
278
- threshold_sum: &mut threshold_sum,
279
- mask_sum: &mut mask_sum,
280
- counts: &mut counts,
281
- },
282
- &shrink_map,
283
- &threshold_map,
284
- &mask_map,
285
- PatchPlacement {
286
- width,
287
- height,
288
- offset_x,
289
- top,
290
- actual_height,
291
- },
292
- );
293
- }
294
- }
295
- }
296
-
297
- let mut raw_shrink_map = accumulate_to_score_map(width, height, &shrink_sum, &counts);
298
- let mut raw_threshold_map = accumulate_to_score_map(width, height, &threshold_sum, &counts);
299
- let mut mask_map = accumulate_to_gray_image(width, height, &mask_sum, &counts)?;
300
-
301
- if transpose {
302
- raw_shrink_map = transpose_score_map(&raw_shrink_map);
303
- raw_threshold_map = transpose_score_map(&raw_threshold_map);
304
- mask_map = transpose_gray_image(&mask_map);
305
- }
306
-
307
- Ok(Some(postprocess::DetectionMaps {
308
- shrink_map: score_map_to_gray_image(&raw_shrink_map)?,
309
- threshold_map: score_map_to_gray_image(&raw_threshold_map)?,
310
- raw_shrink_map,
311
- raw_threshold_map,
312
- mask_map,
313
- }))
314
- }
315
-
316
- fn try_rearranged_mask_map(
317
- &self,
318
- image: &DynamicImage,
319
- detect_size: u32,
320
- max_batch_size: usize,
321
- ) -> anyhow::Result<Option<GrayImage>> {
322
- let Some(layout) = build_rearranged_layout(image, detect_size) else {
323
- return Ok(None);
324
- };
325
- let RearrangedLayout {
326
- transpose,
327
- width,
328
- height,
329
- pw_num,
330
- metadata,
331
- composites,
332
- composite_size,
333
- } = layout;
334
-
335
- let pixel_count = (width * height) as usize;
336
- let mut mask_sum = vec![0.0f32; pixel_count];
337
- let mut counts = vec![0.0f32; pixel_count];
338
-
339
- for batch_start in (0..composites.len()).step_by(max_batch_size.max(1)) {
340
- let batch_end = (batch_start + max_batch_size).min(composites.len());
341
- let mut tensors = Vec::with_capacity(batch_end - batch_start);
342
- for composite in &composites[batch_start..batch_end] {
343
- tensors.push(preprocess_rgb_image(composite, &self.device, detect_size)?);
344
- }
345
- let refs: Vec<&Tensor> = tensors.iter().collect();
346
- let batch = Tensor::cat(&refs, 0)?;
347
- let mask_batch = self.forward_mask(&batch)?;
348
-
349
- for batch_index in 0..(batch_end - batch_start) {
350
- let mask_map = tensor_channel_to_score_map_resized(
351
- &mask_batch.i((batch_index, 0))?,
352
- composite_size,
353
- composite_size,
354
- )?;
355
-
356
- for slot in 0..pw_num as usize {
357
- let patch_index = (batch_start + batch_index) * pw_num as usize + slot;
358
- if patch_index >= metadata.len() {
359
- break;
360
- }
361
- let (top, actual_height) = metadata[patch_index];
362
- if actual_height == 0 {
363
- continue;
364
- }
365
-
366
- let offset_x = slot as u32 * width;
367
- stitch_mask_patch(
368
- &mut mask_sum,
369
- &mut counts,
370
- &mask_map,
371
- PatchPlacement {
372
- width,
373
- height,
374
- offset_x,
375
- top,
376
- actual_height,
377
- },
378
- );
379
- }
380
- }
381
- }
382
-
383
- let mut mask_map = accumulate_to_gray_image(width, height, &mask_sum, &counts)?;
384
- if transpose {
385
- mask_map = transpose_gray_image(&mask_map);
386
- }
387
- Ok(Some(mask_map))
388
- }
389
- }
390
-
391
- struct RearrangedLayout {
392
- transpose: bool,
393
- width: u32,
394
- height: u32,
395
- pw_num: u32,
396
- metadata: Vec<(u32, u32)>,
397
- composites: Vec<RgbImage>,
398
- composite_size: u32,
399
- }
400
-
401
- fn build_rearranged_layout(image: &DynamicImage, detect_size: u32) -> Option<RearrangedLayout> {
402
- let mut working = image.to_rgb8();
403
- let mut transpose = false;
404
- let (mut height, mut width) = working.dimensions();
405
- if height < width {
406
- transpose = true;
407
- working = transpose_rgb_image(&working);
408
- (width, height) = working.dimensions();
409
- }
410
-
411
- let aspect_ratio = height as f32 / width as f32;
412
- let down_scale_ratio = height as f32 / detect_size as f32;
413
- let require_rearrange = down_scale_ratio > DET_REARRANGE_DOWNSCALE_THRESHOLD
414
- && aspect_ratio > DET_REARRANGE_ASPECT_THRESHOLD;
415
- if !require_rearrange {
416
- return None;
417
- }
418
-
419
- let pw_num = (((2 * detect_size) as f32 / width as f32).floor() as u32).max(2);
420
- let patch_height = pw_num * width;
421
- let patch_count = height.div_ceil(patch_height);
422
- let patch_step = if patch_count > 1 {
423
- (height - patch_height) / (patch_count - 1)
424
- } else {
425
- 0
426
- };
427
-
428
- let mut patches = Vec::new();
429
- let mut metadata = Vec::new();
430
- for index in 0..patch_count {
431
- let top = index * patch_step;
432
- let bottom = (top + patch_height).min(height);
433
- let actual_height = bottom.saturating_sub(top);
434
- let crop = imageops::crop_imm(&working, 0, top, width, actual_height).to_image();
435
- let mut padded = RgbImage::from_pixel(width, patch_height, image::Rgb([0, 0, 0]));
436
- imageops::replace(&mut padded, &crop, 0, 0);
437
- patches.push(padded);
438
- metadata.push((top, actual_height));
439
- }
440
-
441
- let composites_per_batch = (patch_count as usize).div_ceil(pw_num as usize);
442
- let total_slots = composites_per_batch * pw_num as usize;
443
- while patches.len() < total_slots {
444
- patches.push(RgbImage::from_pixel(
445
- width,
446
- patch_height,
447
- image::Rgb([0, 0, 0]),
448
- ));
449
- metadata.push((0, 0));
450
- }
451
-
452
- let composite_size = patch_height;
453
- let mut composites = Vec::new();
454
- for chunk in patches.chunks(pw_num as usize) {
455
- let mut composite =
456
- RgbImage::from_pixel(composite_size, composite_size, image::Rgb([0, 0, 0]));
457
- for (slot, patch) in chunk.iter().enumerate() {
458
- imageops::replace(&mut composite, patch, (slot as u32 * width) as i64, 0);
459
- }
460
- composites.push(composite);
461
  }
462
-
463
- Some(RearrangedLayout {
464
- transpose,
465
- width,
466
- height,
467
- pw_num,
468
- metadata,
469
- composites,
470
- composite_size,
471
- })
472
  }
473
 
474
  #[instrument(level = "debug", skip_all)]
475
- fn preprocess(
476
- image: &DynamicImage,
477
- device: &Device,
478
- image_size: u32,
479
- ) -> anyhow::Result<(Tensor, (u32, u32))> {
480
  let (orig_w, orig_h) = image.dimensions();
481
- let scale = (image_size as f32 / orig_w as f32).min(image_size as f32 / orig_h as f32);
482
- let width = ((orig_w as f32 * scale).round() as u32).clamp(1, image_size);
483
- let height = ((orig_h as f32 * scale).round() as u32).clamp(1, image_size);
 
 
 
 
 
 
484
  let (w, h) = (width as usize, height as usize);
485
  let tensor = (Tensor::from_vec(
486
  image.to_rgb8().into_raw(),
@@ -497,72 +212,111 @@ fn preprocess(
497
  Ok((tensor, (width, height)))
498
  }
499
 
500
- fn preprocess_rgb_image(
501
- image: &RgbImage,
502
- device: &Device,
503
- image_size: u32,
504
- ) -> anyhow::Result<Tensor> {
505
- let resized = if image.width() == image_size && image.height() == image_size {
506
- image.clone()
507
- } else {
508
- imageops::resize(
509
- image,
510
- image_size,
511
- image_size,
512
- imageops::FilterType::Triangle,
513
- )
514
- };
515
 
516
- Ok((Tensor::from_vec(
517
- resized.into_raw(),
518
- (1, image_size as usize, image_size as usize, 3),
519
- device,
520
- )?
521
- .permute((0, 3, 1, 2))?
522
- .to_dtype(DType::F32)?
523
- * (1. / 255.))?)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  }
525
 
526
- fn postprocess_maps(
 
527
  mask: &Tensor,
528
  shrink_thresh: &Tensor,
529
  original_dimensions: (u32, u32),
530
  resized_dimensions: (u32, u32),
531
- ) -> anyhow::Result<postprocess::DetectionMaps> {
532
- let shrink = shrink_thresh.i((0, 0))?;
533
- let threshold = shrink_thresh.i((0, 1))?;
534
- let unet_mask = mask.i((0, 0))?;
 
535
 
536
- let (db_h, db_w) = shrink.dims2()?;
537
- let (mask_h, mask_w) = unet_mask.dims2()?;
538
- let h = cmp::min(db_h, mask_h);
539
- let w = cmp::min(db_w, mask_w);
540
- let valid_h = h.min(resized_dimensions.1 as usize);
541
- let valid_w = w.min(resized_dimensions.0 as usize);
542
 
543
- let shrink = shrink.narrow(0, 0, valid_h)?.narrow(1, 0, valid_w)?;
544
- let threshold = threshold.narrow(0, 0, valid_h)?.narrow(1, 0, valid_w)?;
545
- let unet_mask = unet_mask.narrow(0, 0, valid_h)?.narrow(1, 0, valid_w)?;
546
 
547
- let raw_shrink_map = tensor_channel_to_score_map_exact(&shrink)?;
548
- let raw_threshold_map = tensor_channel_to_score_map_exact(&threshold)?;
549
- let shrink_map =
550
- tensor_channel_to_gray_resized(&shrink, original_dimensions.0, original_dimensions.1)?;
551
- let threshold_map =
552
- tensor_channel_to_gray_resized(&threshold, original_dimensions.0, original_dimensions.1)?;
553
- let mask_map =
554
- tensor_channel_to_gray_resized(&unet_mask, original_dimensions.0, original_dimensions.1)?;
555
-
556
- Ok(postprocess::DetectionMaps {
557
- raw_shrink_map,
558
- raw_threshold_map,
559
- shrink_map,
560
- threshold_map,
561
- mask_map,
562
- })
 
 
 
 
 
 
 
 
 
 
 
 
563
  }
564
 
565
- fn postprocess_mask(
566
  mask: &Tensor,
567
  original_dimensions: (u32, u32),
568
  resized_dimensions: (u32, u32),
@@ -580,88 +334,65 @@ fn tensor_channel_to_gray_resized(
580
  width: u32,
581
  height: u32,
582
  ) -> anyhow::Result<GrayImage> {
583
- score_map_to_gray_image(&tensor_channel_to_score_map_resized(tensor, width, height)?)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
  }
585
 
586
- fn tensor_channel_to_score_map_exact(tensor: &Tensor) -> anyhow::Result<postprocess::ScoreMap> {
587
- let (height, width) = tensor.dims2()?;
588
- let values: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
589
- Ok(postprocess::ScoreMap {
590
- width: width as u32,
591
- height: height as u32,
592
- values,
593
- })
 
 
 
 
 
 
 
 
 
 
594
  }
595
 
596
- fn tensor_channel_to_score_map_resized(
597
- tensor: &Tensor,
598
- width: u32,
599
- height: u32,
600
- ) -> anyhow::Result<postprocess::ScoreMap> {
601
- let resized = tensor
602
- .unsqueeze(0)?
603
- .unsqueeze(0)?
604
- .interpolate2d(height as usize, width as usize)?
605
- .squeeze(0)?
606
- .squeeze(0)?;
607
- let values: Vec<f32> = resized.flatten_all()?.to_vec1()?;
608
- Ok(postprocess::ScoreMap {
609
- width,
610
- height,
611
- values,
612
- })
613
  }
614
 
615
- fn stitch_patch(
616
- buffers: StitchBuffers<'_>,
617
- shrink_map: &postprocess::ScoreMap,
618
- threshold_map: &postprocess::ScoreMap,
619
- mask_map: &postprocess::ScoreMap,
620
- placement: PatchPlacement,
621
- ) {
622
- let PatchPlacement {
623
- width,
624
- height,
625
- offset_x,
626
- top,
627
- actual_height,
628
- } = placement;
629
- for y in 0..actual_height.min(height.saturating_sub(top)) {
630
- for x in 0..width {
631
- let global_index = ((top + y) * width + x) as usize;
632
- let source_x = offset_x + x;
633
- buffers.shrink_sum[global_index] += shrink_map.get(source_x, y);
634
- buffers.threshold_sum[global_index] += threshold_map.get(source_x, y);
635
- buffers.mask_sum[global_index] += mask_map.get(source_x, y);
636
- buffers.counts[global_index] += 1.0;
637
- }
638
- }
639
  }
640
 
641
- fn stitch_mask_patch(
642
- mask_sum: &mut [f32],
643
- counts: &mut [f32],
644
- mask_map: &postprocess::ScoreMap,
645
- placement: PatchPlacement,
646
- ) {
647
- let PatchPlacement {
648
- width,
649
- height,
650
- offset_x,
651
- top,
652
- actual_height,
653
- } = placement;
654
- for y in 0..actual_height.min(height.saturating_sub(top)) {
655
- for x in 0..width {
656
- let global_index = ((top + y) * width + x) as usize;
657
- let source_x = offset_x + x;
658
- mask_sum[global_index] += mask_map.get(source_x, y);
659
- counts[global_index] += 1.0;
660
- }
661
- }
662
  }
663
 
664
- pub async fn prefetch_segmentation(runtime: &RuntimeManager) -> anyhow::Result<()> {
665
  let downloads = runtime.downloads();
666
  downloads
667
  .huggingface_model(HF_REPO, "yolo-v5.safetensors")
@@ -669,89 +400,19 @@ pub async fn prefetch_segmentation(runtime: &RuntimeManager) -> anyhow::Result<(
669
  downloads
670
  .huggingface_model(HF_REPO, "unet.safetensors")
671
  .await?;
 
 
 
672
  Ok(())
673
  }
674
 
675
- fn accumulate_to_score_map(
676
- width: u32,
677
- height: u32,
678
- values: &[f32],
679
- counts: &[f32],
680
- ) -> postprocess::ScoreMap {
681
- let values = values
682
- .iter()
683
- .zip(counts.iter())
684
- .map(|(value, count)| {
685
- if *count <= 0.0 {
686
- 0.0
687
- } else {
688
- (value / count).clamp(0.0, 1.0)
689
- }
690
- })
691
- .collect();
692
- postprocess::ScoreMap {
693
- width,
694
- height,
695
- values,
696
- }
697
- }
698
-
699
- fn accumulate_to_gray_image(
700
- width: u32,
701
- height: u32,
702
- values: &[f32],
703
- counts: &[f32],
704
- ) -> anyhow::Result<GrayImage> {
705
- score_map_to_gray_image(&accumulate_to_score_map(width, height, values, counts))
706
- }
707
-
708
- fn transpose_rgb_image(image: &RgbImage) -> RgbImage {
709
- RgbImage::from_fn(image.height(), image.width(), |x, y| *image.get_pixel(y, x))
710
- }
711
-
712
- fn transpose_gray_image(image: &GrayImage) -> GrayImage {
713
- GrayImage::from_fn(image.height(), image.width(), |x, y| *image.get_pixel(y, x))
714
- }
715
-
716
- fn transpose_score_map(score_map: &postprocess::ScoreMap) -> postprocess::ScoreMap {
717
- let mut values = vec![0.0f32; (score_map.width * score_map.height) as usize];
718
- for y in 0..score_map.height {
719
- for x in 0..score_map.width {
720
- values[(x * score_map.height + y) as usize] = score_map.get(x, y);
721
- }
722
- }
723
- postprocess::ScoreMap {
724
- width: score_map.height,
725
- height: score_map.width,
726
- values,
727
- }
728
- }
729
-
730
- fn score_map_to_gray_image(score_map: &postprocess::ScoreMap) -> anyhow::Result<GrayImage> {
731
- let bytes: Vec<u8> = score_map
732
- .values
733
- .iter()
734
- .copied()
735
- .map(float_to_byte)
736
- .collect();
737
- GrayImage::from_raw(score_map.width, score_map.height, bytes)
738
- .context("failed to build CTD map image")
739
- }
740
-
741
- fn float_to_byte(value: f32) -> u8 {
742
- (value.clamp(0.0, 1.0) * 255.0).round() as u8
743
- }
744
-
745
- #[cfg(test)]
746
- mod tests {
747
- use super::*;
748
- use image::Luma;
749
-
750
- #[test]
751
- fn transpose_helpers_round_trip() {
752
- let image = GrayImage::from_fn(3, 5, |x, y| Luma([(x + y * 3) as u8]));
753
- let transposed = transpose_gray_image(&image);
754
- let round_trip = transpose_gray_image(&transposed);
755
- assert_eq!(round_trip, image);
756
- }
757
  }
 
3
  mod unet;
4
  mod yolo_v5;
5
 
6
+ use std::cmp;
7
 
8
+ use anyhow::{Context, bail};
9
  use candle_core::{DType, Device, IndexOp, Tensor};
10
+ use candle_transformers::object_detection::{Bbox, non_maximum_suppression};
11
+ use image::{DynamicImage, GenericImageView, GrayImage};
12
+ use koharu_core::TextBlock;
13
  use koharu_runtime::RuntimeManager;
14
  use tracing::instrument;
15
 
 
20
  refine_segmentation_mask,
21
  };
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  const HF_REPO: &str = "mayocream/comic-text-detector";
24
+ const CONFIDENCE_THRESHOLD: f32 = 0.4;
25
+ const NMS_THRESHOLD: f32 = 0.35;
26
+ const DBNET_BINARIZE_K: f64 = 50.0;
27
+ const BINARY_THRESHOLD: u8 = 60;
28
+ const DILATION_RADIUS: u32 = 3;
29
+ const HOLE_CLOSE_RADIUS: u32 = 10;
30
+ const BBOX_DILATION: f32 = 1.0;
31
+ const GPU_DETECT_SIZE: u32 = 1024;
32
+ const CPU_DETECT_SIZE: u32 = 640;
33
 
34
  koharu_runtime::declare_hf_model_package!(
35
  id: "model:comic-text-detector:yolo-v5",
36
+ repo: HF_REPO,
37
  file: "yolo-v5.safetensors",
38
  bootstrap: true,
39
  order: 110,
40
  );
41
  koharu_runtime::declare_hf_model_package!(
42
  id: "model:comic-text-detector:unet",
43
+ repo: HF_REPO,
44
  file: "unet.safetensors",
45
  bootstrap: true,
46
  order: 111,
47
  );
48
  koharu_runtime::declare_hf_model_package!(
49
  id: "model:comic-text-detector:dbnet",
50
+ repo: HF_REPO,
51
  file: "dbnet.safetensors",
52
  bootstrap: true,
53
  order: 112,
 
112
 
113
  #[instrument(level = "debug", skip_all)]
114
  pub fn inference(&self, image: &DynamicImage) -> anyhow::Result<ComicTextDetection> {
115
+ let original_dimensions = image.dimensions();
116
+ let (image_tensor, resized_dimensions) = preprocess(image, &self.device)?;
117
+ let (predictions, mask, shrink_threshold) = self.forward(&image_tensor)?;
118
+
119
+ let bboxes = postprocess_yolo(&predictions, original_dimensions, resized_dimensions)?;
120
+ let shrink_map = tensor_channel_to_gray_resized(
121
+ &shrink_threshold.i((0, 0))?,
122
+ original_dimensions.0,
123
+ original_dimensions.1,
124
+ )?;
125
+ let threshold_map = tensor_channel_to_gray_resized(
126
+ &shrink_threshold.i((0, 1))?,
127
+ original_dimensions.0,
128
+ original_dimensions.1,
129
+ )?;
130
+ let mask = postprocess_mask(
131
+ &mask,
132
+ &shrink_threshold,
133
+ original_dimensions,
134
+ resized_dimensions,
135
+ )?;
136
 
137
+ Ok(ComicTextDetection {
138
+ shrink_map,
139
+ threshold_map,
140
+ line_polygons: Vec::new(),
141
+ text_blocks: bboxes_to_text_blocks(bboxes),
142
+ mask,
143
+ })
144
  }
145
 
146
  #[instrument(level = "debug", skip_all)]
147
  pub fn inference_segmentation(&self, image: &DynamicImage) -> anyhow::Result<GrayImage> {
148
+ let original_dimensions = image.dimensions();
149
+ let (image_tensor, resized_dimensions) = preprocess(image, &self.device)?;
150
+ let mask = self.forward_mask(&image_tensor)?;
151
+ postprocess_unet_mask(&mask, original_dimensions, resized_dimensions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  }
153
 
154
  #[instrument(level = "debug", skip_all)]
155
+ fn forward(&self, image: &Tensor) -> anyhow::Result<(Tensor, Tensor, Tensor)> {
156
+ let (predictions, features) = self.yolo.forward(image)?;
157
+ let (mask, features) = self.unet.forward(
158
+ &features[0],
159
+ &features[1],
160
+ &features[2],
161
+ &features[3],
162
+ &features[4],
163
+ )?;
164
  let dbnet = self
165
  .dbnet
166
  .as_ref()
167
  .context("DBNet not loaded; use ComicTextDetector::load for full detection")?;
168
  let shrink_thresh = dbnet.forward(&features[0], &features[1], &features[2])?;
169
 
170
+ Ok((predictions, mask, shrink_thresh))
171
  }
172
 
173
  #[instrument(level = "debug", skip_all)]
174
  fn forward_mask(&self, image: &Tensor) -> anyhow::Result<Tensor> {
 
 
 
 
 
 
175
  let (_, features) = self.yolo.forward(image)?;
176
+ let (mask, _) = self.unet.forward(
177
  &features[0],
178
  &features[1],
179
  &features[2],
180
  &features[3],
181
  &features[4],
182
  )?;
183
+ Ok(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  }
 
 
 
 
 
 
 
 
 
 
185
  }
186
 
187
  #[instrument(level = "debug", skip_all)]
188
+ fn preprocess(image: &DynamicImage, device: &Device) -> anyhow::Result<(Tensor, (u32, u32))> {
 
 
 
 
189
  let (orig_w, orig_h) = image.dimensions();
190
+ let image_size = match device {
191
+ Device::Cpu => CPU_DETECT_SIZE,
192
+ _ => GPU_DETECT_SIZE,
193
+ };
194
+ let (width, height) = if orig_w >= orig_h {
195
+ (image_size, image_size * orig_h / orig_w)
196
+ } else {
197
+ (image_size * orig_w / orig_h, image_size)
198
+ };
199
  let (w, h) = (width as usize, height as usize);
200
  let tensor = (Tensor::from_vec(
201
  image.to_rgb8().into_raw(),
 
212
  Ok((tensor, (width, height)))
213
  }
214
 
215
+ #[instrument(level = "debug", skip(predictions))]
216
+ fn postprocess_yolo(
217
+ predictions: &Tensor,
218
+ original_dimensions: (u32, u32),
219
+ resized_dimensions: (u32, u32),
220
+ ) -> anyhow::Result<Vec<Bbox<usize>>> {
221
+ let predictions = predictions.squeeze(0)?;
222
+ let (_, num_outputs) = predictions.dims2()?;
223
+ if num_outputs < 6 {
224
+ bail!("invalid prediction shape: expected at least 6 outputs, got {num_outputs}");
225
+ }
 
 
 
 
226
 
227
+ let num_classes = num_outputs - 5;
228
+ let (orig_w, orig_h) = original_dimensions;
229
+ let (resized_w, resized_h) = resized_dimensions;
230
+ let w_ratio = orig_w as f32 / resized_w as f32;
231
+ let h_ratio = orig_h as f32 / resized_h as f32;
232
+
233
+ let mut bboxes: Vec<Vec<Bbox<usize>>> = (0..num_classes).map(|_| Vec::new()).collect();
234
+ let predictions: Vec<Vec<f32>> = predictions.to_vec2()?;
235
+ for pred in predictions {
236
+ let (class_index, confidence) = {
237
+ let (cls_idx, cls_score) = pred[5..]
238
+ .iter()
239
+ .copied()
240
+ .enumerate()
241
+ .max_by(|a, b| a.1.total_cmp(&b.1))
242
+ .unwrap_or((0, 0.0));
243
+ (cls_idx, pred[4] * cls_score)
244
+ };
245
+ if confidence < CONFIDENCE_THRESHOLD {
246
+ continue;
247
+ }
248
+
249
+ let xmin = ((pred[0] - pred[2] / 2.) * w_ratio - BBOX_DILATION).clamp(0., orig_w as f32);
250
+ let xmax = ((pred[0] + pred[2] / 2.) * w_ratio + BBOX_DILATION).clamp(0., orig_w as f32);
251
+ let ymin = ((pred[1] - pred[3] / 2.) * h_ratio - BBOX_DILATION).clamp(0., orig_h as f32);
252
+ let ymax = ((pred[1] + pred[3] / 2.) * h_ratio + BBOX_DILATION).clamp(0., orig_h as f32);
253
+
254
+ bboxes[class_index].push(Bbox {
255
+ xmin,
256
+ xmax,
257
+ ymin,
258
+ ymax,
259
+ confidence,
260
+ data: class_index,
261
+ });
262
+ }
263
+
264
+ non_maximum_suppression(&mut bboxes, NMS_THRESHOLD);
265
+ Ok(bboxes.into_iter().flatten().collect())
266
  }
267
 
268
+ #[instrument(level = "debug", skip(mask, shrink_thresh))]
269
+ fn postprocess_mask(
270
  mask: &Tensor,
271
  shrink_thresh: &Tensor,
272
  original_dimensions: (u32, u32),
273
  resized_dimensions: (u32, u32),
274
+ ) -> anyhow::Result<GrayImage> {
275
+ let shrink_and_thresh = shrink_thresh.squeeze(0)?;
276
+ let shrink = shrink_and_thresh.i(0)?;
277
+ let thresh = shrink_and_thresh.i(1)?;
278
+ let unet_mask = mask.squeeze(0)?;
279
 
280
+ let (_, h_db, w_db) = shrink_and_thresh.dims3()?;
281
+ let (_, h_unet, w_unet) = unet_mask.dims3()?;
282
+ let h = cmp::min(h_db, h_unet);
283
+ let w = cmp::min(w_db, w_unet);
 
 
284
 
285
+ let shrink = shrink.narrow(0, 0, h)?.narrow(1, 0, w)?;
286
+ let thresh = thresh.narrow(0, 0, h)?.narrow(1, 0, w)?;
287
+ let unet_mask = unet_mask.narrow(1, 0, h)?.narrow(2, 0, w)?.squeeze(0)?;
288
 
289
+ let prob = candle_nn::ops::sigmoid(&((&shrink - &thresh)? * DBNET_BINARIZE_K)?)?;
290
+ let fused = prob.maximum(&unet_mask)?;
291
+
292
+ let (mask_h, mask_w) = fused.dims2()?;
293
+ let valid_h = mask_h.min(resized_dimensions.1 as usize);
294
+ let valid_w = mask_w.min(resized_dimensions.0 as usize);
295
+
296
+ let fused = fused
297
+ .narrow(0, 0, valid_h)?
298
+ .narrow(1, 0, valid_w)?
299
+ .unsqueeze(0)?
300
+ .unsqueeze(0)?;
301
+
302
+ let resized = fused.interpolate2d(
303
+ original_dimensions.1 as usize,
304
+ original_dimensions.0 as usize,
305
+ )?;
306
+ let threshold = BINARY_THRESHOLD as f32 / 255.0;
307
+ let binary = resized.ge(threshold)?.to_dtype(DType::F32)?;
308
+
309
+ let closed = morph_close(&binary, HOLE_CLOSE_RADIUS as usize)?;
310
+ let dilated = dilate_tensor(&closed, DILATION_RADIUS as usize)?;
311
+ let mask = dilated.squeeze(0)?.squeeze(0)?;
312
+
313
+ let mask = (mask * 255.)?.to_dtype(DType::U8)?;
314
+ let data: Vec<u8> = mask.flatten_all()?.to_vec1()?;
315
+ GrayImage::from_raw(original_dimensions.0, original_dimensions.1, data)
316
+ .context("failed to build mask image")
317
  }
318
 
319
+ fn postprocess_unet_mask(
320
  mask: &Tensor,
321
  original_dimensions: (u32, u32),
322
  resized_dimensions: (u32, u32),
 
334
  width: u32,
335
  height: u32,
336
  ) -> anyhow::Result<GrayImage> {
337
+ let (th, tw) = tensor.dims2()?;
338
+ let values: Vec<f32> = tensor.to_device(&Device::Cpu)?.flatten_all()?.to_vec1()?;
339
+ let pixels: Vec<u8> = values
340
+ .iter()
341
+ .map(|&v| (v.clamp(0.0, 1.0) * 255.0).round() as u8)
342
+ .collect();
343
+ let small = GrayImage::from_raw(tw as u32, th as u32, pixels)
344
+ .context("failed to create gray image from tensor")?;
345
+ if tw as u32 == width && th as u32 == height {
346
+ return Ok(small);
347
+ }
348
+ Ok(image::imageops::resize(
349
+ &small,
350
+ width,
351
+ height,
352
+ image::imageops::FilterType::Nearest,
353
+ ))
354
  }
355
 
356
+ fn bboxes_to_text_blocks(mut bboxes: Vec<Bbox<usize>>) -> Vec<TextBlock> {
357
+ bboxes.sort_unstable_by(|a, b| {
358
+ let ay = a.ymin + (a.ymax - a.ymin) * 0.5;
359
+ let by = b.ymin + (b.ymax - b.ymin) * 0.5;
360
+ ay.partial_cmp(&by).unwrap_or(std::cmp::Ordering::Equal)
361
+ });
362
+
363
+ bboxes
364
+ .into_iter()
365
+ .map(|bbox| TextBlock {
366
+ x: bbox.xmin,
367
+ y: bbox.ymin,
368
+ width: bbox.xmax - bbox.xmin,
369
+ height: bbox.ymax - bbox.ymin,
370
+ confidence: bbox.confidence,
371
+ ..Default::default()
372
+ })
373
+ .collect()
374
  }
375
 
376
+ fn dilate_tensor(mask: &Tensor, radius: usize) -> anyhow::Result<Tensor> {
377
+ let kernel = 2 * radius + 1;
378
+ let padded = mask
379
+ .pad_with_zeros(2, radius, radius)?
380
+ .pad_with_zeros(3, radius, radius)?;
381
+ Ok(padded.max_pool2d_with_stride((kernel, kernel), (1, 1))?)
 
 
 
 
 
 
 
 
 
 
 
382
  }
383
 
384
+ fn erode_tensor(mask: &Tensor, radius: usize) -> anyhow::Result<Tensor> {
385
+ let inverted = (1.0 - mask)?;
386
+ let dilated = dilate_tensor(&inverted, radius)?;
387
+ Ok((1.0 - dilated)?)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  }
389
 
390
+ fn morph_close(mask: &Tensor, radius: usize) -> anyhow::Result<Tensor> {
391
+ let dilated = dilate_tensor(mask, radius)?;
392
+ erode_tensor(&dilated, radius)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  }
394
 
395
+ pub async fn prefetch(runtime: &RuntimeManager) -> anyhow::Result<()> {
396
  let downloads = runtime.downloads();
397
  downloads
398
  .huggingface_model(HF_REPO, "yolo-v5.safetensors")
 
400
  downloads
401
  .huggingface_model(HF_REPO, "unet.safetensors")
402
  .await?;
403
+ downloads
404
+ .huggingface_model(HF_REPO, "dbnet.safetensors")
405
+ .await?;
406
  Ok(())
407
  }
408
 
409
+ pub async fn prefetch_segmentation(runtime: &RuntimeManager) -> anyhow::Result<()> {
410
+ let downloads = runtime.downloads();
411
+ downloads
412
+ .huggingface_model(HF_REPO, "yolo-v5.safetensors")
413
+ .await?;
414
+ downloads
415
+ .huggingface_model(HF_REPO, "unet.safetensors")
416
+ .await?;
417
+ Ok(())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  }
koharu-ml/src/comic_text_detector/postprocess.rs CHANGED
@@ -3,21 +3,12 @@ use image::{
3
  imageops::{self},
4
  };
5
  use imageproc::{
6
- contours::{BorderType as ContourBorderType, find_contours},
7
- contrast::otsu_level,
8
  distance_transform::Norm,
9
- drawing::draw_polygon_mut,
10
  geometric_transformations::{Interpolation, Projection, warp_into},
11
- morphology::{dilate, erode},
12
- point::Point,
13
- region_labelling::{Connectivity, connected_components},
14
  };
15
  use koharu_core::{TextBlock, TextDirection};
16
 
17
- const LINE_THRESHOLD: f32 = 0.3;
18
- const LINE_SCORE_THRESHOLD: f32 = 0.6;
19
- const MASK_SCORE_THRESHOLD: f32 = 0.1;
20
- const CTD_UNCLIP_RATIO: f32 = 1.5;
21
  const FINAL_MASK_DILATE_RADIUS: u8 = 2;
22
 
23
  pub type Quad = [[f32; 2]; 4];
@@ -31,190 +22,12 @@ pub struct ComicTextDetection {
31
  pub mask: GrayImage,
32
  }
33
 
34
- #[derive(Debug, Clone)]
35
- pub(crate) struct ScoreMap {
36
- pub width: u32,
37
- pub height: u32,
38
- pub values: Vec<f32>,
39
- }
40
-
41
- impl ScoreMap {
42
- pub(crate) fn get(&self, x: u32, y: u32) -> f32 {
43
- self.values[(y * self.width + x) as usize]
44
- }
45
- }
46
-
47
- #[derive(Debug, Clone)]
48
- pub(crate) struct DetectionMaps {
49
- pub raw_shrink_map: ScoreMap,
50
- pub raw_threshold_map: ScoreMap,
51
- pub shrink_map: GrayImage,
52
- pub threshold_map: GrayImage,
53
- pub mask_map: GrayImage,
54
- }
55
-
56
- #[derive(Debug, Clone)]
57
- struct DetectedLine {
58
- quad: Quad,
59
- vertical: bool,
60
- score: f32,
61
- }
62
-
63
- #[derive(Debug, Clone)]
64
- struct CtdBlock {
65
- bbox: [f32; 4],
66
- confidence: f32,
67
- source_language: String,
68
- source_direction: TextDirection,
69
- lines: Vec<Quad>,
70
- angle_deg: f32,
71
- detected_font_size_px: f32,
72
- distances: Vec<f32>,
73
- direction_vec: [f32; 2],
74
- direction_norm: f32,
75
- merged: bool,
76
- }
77
-
78
- #[derive(Debug, Clone)]
79
- struct Component {
80
- x: u32,
81
- y: u32,
82
- w: u32,
83
- h: u32,
84
- area: u32,
85
- pixels: Vec<(u32, u32)>,
86
- }
87
-
88
- #[derive(Debug, Clone)]
89
- struct CandidateMask {
90
- mask: GrayImage,
91
- xor_sum: u64,
92
- }
93
-
94
- impl CtdBlock {
95
- fn from_line(line: &DetectedLine) -> Self {
96
- let bbox = quad_bbox(&line.quad);
97
- Self {
98
- bbox,
99
- confidence: line.score,
100
- source_language: "unknown".to_string(),
101
- source_direction: if line.vertical {
102
- TextDirection::Vertical
103
- } else {
104
- TextDirection::Horizontal
105
- },
106
- lines: vec![line.quad],
107
- angle_deg: 0.0,
108
- detected_font_size_px: 0.0,
109
- distances: Vec::new(),
110
- direction_vec: [1.0, 0.0],
111
- direction_norm: 1.0,
112
- merged: false,
113
- }
114
- }
115
-
116
- fn center(&self) -> [f32; 2] {
117
- [
118
- (self.bbox[0] + self.bbox[2]) * 0.5,
119
- (self.bbox[1] + self.bbox[3]) * 0.5,
120
- ]
121
- }
122
-
123
- fn adjust_bbox(&mut self, with_bbox: bool) {
124
- if self.lines.is_empty() {
125
- return;
126
- }
127
-
128
- let mut min_x = f32::MAX;
129
- let mut min_y = f32::MAX;
130
- let mut max_x = f32::MIN;
131
- let mut max_y = f32::MIN;
132
- for line in &self.lines {
133
- let bbox = quad_bbox(line);
134
- min_x = min_x.min(bbox[0]);
135
- min_y = min_y.min(bbox[1]);
136
- max_x = max_x.max(bbox[2]);
137
- max_y = max_y.max(bbox[3]);
138
- }
139
-
140
- if with_bbox {
141
- self.bbox[0] = self.bbox[0].min(min_x);
142
- self.bbox[1] = self.bbox[1].min(min_y);
143
- self.bbox[2] = self.bbox[2].max(max_x);
144
- self.bbox[3] = self.bbox[3].max(max_y);
145
- } else {
146
- self.bbox = [min_x, min_y, max_x, max_y];
147
- }
148
- }
149
-
150
- fn sort_lines(&mut self) {
151
- if self.distances.len() != self.lines.len() {
152
- return;
153
- }
154
-
155
- let mut indexed: Vec<(f32, Quad)> = self
156
- .distances
157
- .iter()
158
- .copied()
159
- .zip(self.lines.iter().copied())
160
- .collect();
161
- indexed.sort_by(|a, b| a.0.total_cmp(&b.0));
162
- self.distances = indexed.iter().map(|(distance, _)| *distance).collect();
163
- self.lines = indexed.into_iter().map(|(_, line)| line).collect();
164
- }
165
- }
166
-
167
- pub fn build_detection(
168
- image: &DynamicImage,
169
- maps: DetectionMaps,
170
- ) -> anyhow::Result<ComicTextDetection> {
171
- let DetectionMaps {
172
- raw_shrink_map,
173
- raw_threshold_map,
174
- shrink_map,
175
- threshold_map,
176
- mask_map,
177
- } = maps;
178
- let scale_x = image.width() as f32 / raw_shrink_map.width.max(1) as f32;
179
- let scale_y = image.height() as f32 / raw_shrink_map.height.max(1) as f32;
180
- let detected_lines = extract_detected_lines(&raw_shrink_map, &raw_threshold_map)
181
- .into_iter()
182
- .map(|mut line| {
183
- if (scale_x - 1.0).abs() > f32::EPSILON || (scale_y - 1.0).abs() > f32::EPSILON {
184
- line.quad = scale_quad(&line.quad, scale_x, scale_y);
185
- }
186
- line
187
- })
188
- .collect::<Vec<_>>();
189
- let line_polygons = detected_lines.iter().map(|line| line.quad).collect();
190
- let text_blocks = group_output(&detected_lines, &mask_map, image.width(), image.height());
191
- let mask = refine_segmentation_mask(image, &mask_map, &text_blocks);
192
-
193
- Ok(ComicTextDetection {
194
- shrink_map,
195
- threshold_map,
196
- line_polygons,
197
- text_blocks,
198
- mask,
199
- })
200
- }
201
-
202
  pub fn refine_segmentation_mask(
203
- image: &DynamicImage,
204
  pred_mask: &GrayImage,
205
- blocks: &[TextBlock],
206
  ) -> GrayImage {
207
- let base = if blocks.is_empty() {
208
- threshold_binary(pred_mask, 60)
209
- } else {
210
- let refined = refine_mask(&image.to_rgb8(), pred_mask, blocks);
211
- if refined.pixels().any(|pixel| pixel[0] > 0) {
212
- refined
213
- } else {
214
- threshold_binary(pred_mask, 60)
215
- }
216
- };
217
-
218
  dilate(&base, Norm::L1, FINAL_MASK_DILATE_RADIUS)
219
  }
220
 
@@ -403,1541 +216,142 @@ fn maybe_expand_ctd_line(block: &TextBlock, line: &Quad) -> Quad {
403
  out
404
  }
405
 
406
- fn extract_detected_lines(shrink_map: &ScoreMap, _threshold_map: &ScoreMap) -> Vec<DetectedLine> {
407
- let binary = GrayImage::from_fn(shrink_map.width, shrink_map.height, |x, y| {
408
- let shrink_score = shrink_map.get(x, y);
409
- if shrink_score > LINE_THRESHOLD {
410
- Luma([255u8])
411
- } else {
412
- Luma([0u8])
413
- }
414
- });
415
-
416
- let contours = find_contours::<i32>(&binary);
417
- let mut lines = Vec::new();
418
- for contour in contours {
419
- if contour.border_type != ContourBorderType::Outer || contour.points.len() < 4 {
420
- continue;
421
- }
422
-
423
- let contour_points = contour
424
- .points
425
- .into_iter()
426
- .map(|point| [point.x as f32, point.y as f32])
427
- .collect::<Vec<_>>();
428
- let score = contour_score_fast(shrink_map, &contour_points);
429
- if score < LINE_SCORE_THRESHOLD {
430
- continue;
431
- }
432
-
433
- if let Some((quad, vertical)) = contour_quad(&contour_points) {
434
- let (norm_v, norm_h) = quad_axis_lengths(&quad);
435
- if norm_v.min(norm_h) < 2.0 {
436
- continue;
437
- }
438
- lines.push(DetectedLine {
439
- quad,
440
- vertical,
441
- score,
442
- });
443
- }
444
- }
445
-
446
- lines
447
- }
448
-
449
- #[cfg(test)]
450
- fn component_quad(component: &Component) -> Option<(Quad, bool)> {
451
- if component.pixels.len() < 2 {
452
- return None;
453
- }
454
-
455
- let mut mean_x = 0.0f32;
456
- let mut mean_y = 0.0f32;
457
- for (x, y) in &component.pixels {
458
- mean_x += *x as f32;
459
- mean_y += *y as f32;
460
- }
461
- let n = component.pixels.len() as f32;
462
- mean_x /= n;
463
- mean_y /= n;
464
-
465
- let mut sxx = 0.0f32;
466
- let mut syy = 0.0f32;
467
- let mut sxy = 0.0f32;
468
- for (x, y) in &component.pixels {
469
- let dx = *x as f32 - mean_x;
470
- let dy = *y as f32 - mean_y;
471
- sxx += dx * dx;
472
- syy += dy * dy;
473
- sxy += dx * dy;
474
- }
475
- let angle = 0.5 * (2.0 * sxy).atan2(sxx - syy);
476
- let ux = angle.cos();
477
- let uy = angle.sin();
478
- let vx = -uy;
479
- let vy = ux;
480
-
481
- let mut min_u = f32::MAX;
482
- let mut max_u = f32::MIN;
483
- let mut min_v = f32::MAX;
484
- let mut max_v = f32::MIN;
485
- for (x, y) in &component.pixels {
486
- let dx = *x as f32 - mean_x;
487
- let dy = *y as f32 - mean_y;
488
- let u = dx * ux + dy * uy;
489
- let v = dx * vx + dy * vy;
490
- min_u = min_u.min(u);
491
- max_u = max_u.max(u);
492
- min_v = min_v.min(v);
493
- max_v = max_v.max(v);
494
- }
495
-
496
- let width = (max_u - min_u).max(1.0);
497
- let height = (max_v - min_v).max(1.0);
498
- let perimeter = 2.0 * (width + height);
499
- let offset = if perimeter > 0.0 {
500
- (width * height * CTD_UNCLIP_RATIO) / perimeter
501
- } else {
502
- 0.0
503
- };
504
- min_u -= offset;
505
- max_u += offset;
506
- min_v -= offset;
507
- max_v += offset;
508
-
509
- let quad = [
510
- [
511
- mean_x + ux * min_u + vx * min_v,
512
- mean_y + uy * min_u + vy * min_v,
513
- ],
514
- [
515
- mean_x + ux * max_u + vx * min_v,
516
- mean_y + uy * max_u + vy * min_v,
517
- ],
518
- [
519
- mean_x + ux * max_u + vx * max_v,
520
- mean_y + uy * max_u + vy * max_v,
521
- ],
522
- [
523
- mean_x + ux * min_u + vx * max_v,
524
- mean_y + uy * min_u + vy * max_v,
525
- ],
526
- ];
527
- let (quad, vertical) = sort_quad_points(&quad);
528
- Some((quad, vertical))
529
- }
530
-
531
- fn contour_quad(points: &[[f32; 2]]) -> Option<(Quad, bool)> {
532
- let quad = minimum_area_rect(points)?;
533
- let area = polygon_area(&quad);
534
- let perimeter = polygon_perimeter(&quad);
535
- let offset = if perimeter > 0.0 {
536
- (area * CTD_UNCLIP_RATIO) / perimeter
537
- } else {
538
- 0.0
539
- };
540
- let expanded = expand_quad(&quad, offset);
541
- let (quad, vertical) = sort_quad_points(&expanded);
542
- Some((quad, vertical))
543
- }
544
-
545
- fn minimum_area_rect(points: &[[f32; 2]]) -> Option<Quad> {
546
- let hull = convex_hull(points);
547
- if hull.len() < 3 {
548
- return None;
549
- }
550
-
551
- let mut best_area = f32::MAX;
552
- let mut best_quad = None;
553
- for index in 0..hull.len() {
554
- let next = (index + 1) % hull.len();
555
- let edge = [
556
- hull[next][0] - hull[index][0],
557
- hull[next][1] - hull[index][1],
558
- ];
559
- let edge_norm = vector_norm(edge);
560
- if edge_norm <= 1e-6 {
561
- continue;
562
- }
563
-
564
- let axis_u = [edge[0] / edge_norm, edge[1] / edge_norm];
565
- let axis_v = [-axis_u[1], axis_u[0]];
566
- let mut min_u = f32::MAX;
567
- let mut max_u = f32::MIN;
568
- let mut min_v = f32::MAX;
569
- let mut max_v = f32::MIN;
570
-
571
- for point in &hull {
572
- let proj_u = dot(*point, axis_u);
573
- let proj_v = dot(*point, axis_v);
574
- min_u = min_u.min(proj_u);
575
- max_u = max_u.max(proj_u);
576
- min_v = min_v.min(proj_v);
577
- max_v = max_v.max(proj_v);
578
- }
579
-
580
- let width = max_u - min_u;
581
- let height = max_v - min_v;
582
- let area = width * height;
583
- if area >= best_area {
584
- continue;
585
- }
586
-
587
- best_area = area;
588
- best_quad = Some([
589
- [
590
- axis_u[0] * min_u + axis_v[0] * min_v,
591
- axis_u[1] * min_u + axis_v[1] * min_v,
592
- ],
593
- [
594
- axis_u[0] * max_u + axis_v[0] * min_v,
595
- axis_u[1] * max_u + axis_v[1] * min_v,
596
- ],
597
- [
598
- axis_u[0] * max_u + axis_v[0] * max_v,
599
- axis_u[1] * max_u + axis_v[1] * max_v,
600
- ],
601
- [
602
- axis_u[0] * min_u + axis_v[0] * max_v,
603
- axis_u[1] * min_u + axis_v[1] * max_v,
604
- ],
605
- ]);
606
- }
607
-
608
- best_quad
609
- }
610
-
611
- fn expand_quad(quad: &Quad, offset: f32) -> Quad {
612
- if offset <= 0.0 {
613
- return *quad;
614
- }
615
-
616
- let axis_u = [quad[1][0] - quad[0][0], quad[1][1] - quad[0][1]];
617
- let axis_v = [quad[3][0] - quad[0][0], quad[3][1] - quad[0][1]];
618
- let norm_u = vector_norm(axis_u).max(1e-6);
619
- let norm_v = vector_norm(axis_v).max(1e-6);
620
- let unit_u = [axis_u[0] / norm_u, axis_u[1] / norm_u];
621
- let unit_v = [axis_v[0] / norm_v, axis_v[1] / norm_v];
622
- let signs = [[-1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [-1.0, 1.0]];
623
-
624
- let mut out = *quad;
625
- for (index, point) in out.iter_mut().enumerate() {
626
- point[0] += unit_u[0] * signs[index][0] * offset + unit_v[0] * signs[index][1] * offset;
627
- point[1] += unit_u[1] * signs[index][0] * offset + unit_v[1] * signs[index][1] * offset;
628
- }
629
- out
630
- }
631
-
632
- fn convex_hull(points: &[[f32; 2]]) -> Vec<[f32; 2]> {
633
- let mut points = points.to_vec();
634
- points.sort_by(|a, b| a[0].total_cmp(&b[0]).then_with(|| a[1].total_cmp(&b[1])));
635
- points.dedup_by(|a, b| (a[0] - b[0]).abs() < 1e-6 && (a[1] - b[1]).abs() < 1e-6);
636
- if points.len() <= 2 {
637
- return points;
638
- }
639
-
640
- let mut lower: Vec<[f32; 2]> = Vec::new();
641
- for point in &points {
642
- while lower.len() >= 2
643
- && cross_2d(
644
- [
645
- lower[lower.len() - 1][0] - lower[lower.len() - 2][0],
646
- lower[lower.len() - 1][1] - lower[lower.len() - 2][1],
647
- ],
648
- [
649
- point[0] - lower[lower.len() - 1][0],
650
- point[1] - lower[lower.len() - 1][1],
651
- ],
652
- ) <= 0.0
653
- {
654
- lower.pop();
655
- }
656
- lower.push(*point);
657
- }
658
-
659
- let mut upper: Vec<[f32; 2]> = Vec::new();
660
- for point in points.iter().rev() {
661
- while upper.len() >= 2
662
- && cross_2d(
663
- [
664
- upper[upper.len() - 1][0] - upper[upper.len() - 2][0],
665
- upper[upper.len() - 1][1] - upper[upper.len() - 2][1],
666
- ],
667
- [
668
- point[0] - upper[upper.len() - 1][0],
669
- point[1] - upper[upper.len() - 1][1],
670
- ],
671
- ) <= 0.0
672
- {
673
- upper.pop();
674
- }
675
- upper.push(*point);
676
- }
677
-
678
- lower.pop();
679
- upper.pop();
680
- lower.extend(upper);
681
- lower
682
- }
683
-
684
- fn polygon_area(quad: &Quad) -> f32 {
685
- let mut area = 0.0;
686
- for index in 0..quad.len() {
687
- let next = (index + 1) % quad.len();
688
- area += quad[index][0] * quad[next][1] - quad[next][0] * quad[index][1];
689
- }
690
- area.abs() * 0.5
691
- }
692
-
693
- fn polygon_perimeter(quad: &Quad) -> f32 {
694
- let mut perimeter = 0.0;
695
- for index in 0..quad.len() {
696
- let next = (index + 1) % quad.len();
697
- perimeter += vector_norm([
698
- quad[next][0] - quad[index][0],
699
- quad[next][1] - quad[index][1],
700
- ]);
701
  }
702
- perimeter
703
  }
704
 
705
- fn contour_score_fast(image: &ScoreMap, polygon: &[[f32; 2]]) -> f32 {
706
- if polygon.is_empty() {
707
- return 0.0;
708
- }
709
-
710
  let mut min_x = f32::MAX;
711
  let mut min_y = f32::MAX;
712
  let mut max_x = f32::MIN;
713
  let mut max_y = f32::MIN;
714
- for point in polygon {
715
  min_x = min_x.min(point[0]);
716
  min_y = min_y.min(point[1]);
717
  max_x = max_x.max(point[0]);
718
  max_y = max_y.max(point[1]);
719
  }
720
-
721
- let x1 = min_x
722
- .floor()
723
- .clamp(0.0, image.width.saturating_sub(1) as f32) as u32;
724
- let y1 = min_y
725
- .floor()
726
- .clamp(0.0, image.height.saturating_sub(1) as f32) as u32;
727
- let x2 = max_x.ceil().clamp(x1 as f32 + 1.0, image.width as f32) as u32;
728
- let y2 = max_y.ceil().clamp(y1 as f32 + 1.0, image.height as f32) as u32;
729
-
730
- let mut mask = GrayImage::new(x2 - x1, y2 - y1);
731
- let shifted = polygon
732
- .iter()
733
- .map(|point| {
734
- Point::new(
735
- (point[0] - x1 as f32).round() as i32,
736
- (point[1] - y1 as f32).round() as i32,
737
- )
738
- })
739
- .collect::<Vec<_>>();
740
- draw_polygon_mut(&mut mask, &shifted, Luma([1u8]));
741
-
742
- let mut sum = 0.0;
743
- let mut count = 0.0;
744
- for y in y1..y2 {
745
- for x in x1..x2 {
746
- if mask.get_pixel(x - x1, y - y1)[0] > 0 {
747
- sum += image.get(x, y);
748
- count += 1.0;
749
- }
750
- }
751
- }
752
-
753
- if count <= 0.0 { 0.0 } else { sum / count }
754
  }
755
 
756
- fn group_output(
757
- lines: &[DetectedLine],
758
- mask: &GrayImage,
759
- image_width: u32,
760
- image_height: u32,
761
- ) -> Vec<TextBlock> {
762
- let mut scattered_horizontal = Vec::new();
763
- let mut scattered_vertical = Vec::new();
764
-
765
- for line in lines {
766
- let line_bbox = quad_bbox(&line.quad);
767
- if mean_mask_score(mask, &line_bbox) >= MASK_SCORE_THRESHOLD {
768
- let mut block = CtdBlock::from_line(line);
769
- examine_block(&mut block, image_width, image_height, false);
770
- if block.source_direction == TextDirection::Vertical {
771
- scattered_vertical.push(block);
772
- } else {
773
- scattered_horizontal.push(block);
774
- }
775
- }
776
- }
777
-
778
- let mut final_blocks = Vec::new();
779
- scattered_vertical.sort_by(|a, b| b.center()[0].total_cmp(&a.center()[0]));
780
- scattered_horizontal.sort_by(|a, b| a.center()[1].total_cmp(&b.center()[1]));
781
-
782
- final_blocks.extend(merge_text_lines(scattered_horizontal, 2.0));
783
- final_blocks.extend(merge_text_lines(scattered_vertical, 1.7));
784
- final_blocks = merge_paragraph_blocks(final_blocks, image_width, image_height);
785
-
786
- let mut sorted = sort_regions(final_blocks);
787
- dedupe_blocks(&mut sorted);
788
- sorted.into_iter().map(block_to_text_block).collect()
789
  }
790
 
791
- fn block_to_text_block(block: CtdBlock) -> TextBlock {
792
- let width = (block.bbox[2] - block.bbox[0]).max(1.0);
793
- let height = (block.bbox[3] - block.bbox[1]).max(1.0);
794
- TextBlock {
795
- x: block.bbox[0],
796
- y: block.bbox[1],
797
- width,
798
- height,
799
- confidence: block.confidence,
800
- line_polygons: Some(block.lines),
801
- source_direction: Some(block.source_direction),
802
- source_language: Some(block.source_language),
803
- rotation_deg: Some(block.angle_deg),
804
- detected_font_size_px: Some(block.detected_font_size_px.max(1.0)),
805
- detector: Some("ctd".to_string()),
806
- ..Default::default()
807
- }
808
  }
809
 
810
- fn examine_block(block: &mut CtdBlock, image_width: u32, image_height: u32, sort: bool) {
811
- if block.lines.is_empty() {
812
- block.detected_font_size_px = block.bbox_height().min(block.bbox_width()).max(1.0);
813
- return;
814
- }
815
-
816
- let mut centers = Vec::with_capacity(block.lines.len());
817
- let mut vec_v_sum = [0.0f32, 0.0f32];
818
- let mut vec_h_sum = [0.0f32, 0.0f32];
819
- let mut font_acc = 0.0f32;
820
-
821
- for line in &block.lines {
822
- let middle = quad_midpoints(line);
823
- let vec_v = [middle[2][0] - middle[0][0], middle[2][1] - middle[0][1]];
824
- let vec_h = [middle[1][0] - middle[3][0], middle[1][1] - middle[3][1]];
825
- vec_v_sum[0] += vec_v[0];
826
- vec_v_sum[1] += vec_v[1];
827
- vec_h_sum[0] += vec_h[0];
828
- vec_h_sum[1] += vec_h[1];
829
- centers.push([
830
- (line[0][0] + line[2][0]) * 0.5,
831
- (line[0][1] + line[2][1]) * 0.5,
832
- ]);
833
- font_acc += match block.source_direction {
834
- TextDirection::Vertical => vector_norm(vec_h),
835
- TextDirection::Horizontal => vector_norm(vec_v),
836
- };
837
- }
838
-
839
- let (primary_vec, primary_norm) = match block.source_direction {
840
- TextDirection::Vertical => (vec_v_sum, vector_norm(vec_v_sum)),
841
- TextDirection::Horizontal => (vec_h_sum, vector_norm(vec_h_sum)),
842
- };
843
-
844
- block.detected_font_size_px = (font_acc / block.lines.len() as f32).max(1.0);
845
- block.direction_vec = primary_vec;
846
- block.direction_norm = primary_norm.max(1.0);
847
- block.distances = centers
848
- .iter()
849
- .map(|center| {
850
- let origin = match block.source_direction {
851
- TextDirection::Vertical => [image_width as f32, 0.0],
852
- TextDirection::Horizontal => [0.0, 0.0],
853
- };
854
- perpendicular_distance(
855
- [center[0] - origin[0], center[1] - origin[1]],
856
- primary_vec,
857
- image_height as f32,
858
- )
859
- })
860
- .collect();
861
-
862
- let mut angle = primary_vec[1].atan2(primary_vec[0]).to_degrees();
863
- if block.source_direction == TextDirection::Vertical {
864
- angle -= 90.0;
865
- }
866
- if angle.abs() < 3.0 {
867
- angle = 0.0;
868
- }
869
- block.angle_deg = angle;
870
- if sort {
871
- block.sort_lines();
872
- }
873
  }
874
 
875
- fn merge_text_lines(blocks: Vec<CtdBlock>, font_size_tol: f32) -> Vec<CtdBlock> {
876
- if blocks.len() < 2 {
877
- return blocks;
878
- }
879
-
880
- let mut blocks = blocks;
881
- let mut merged = Vec::new();
882
- for index in 0..blocks.len() {
883
- if blocks[index].merged {
884
- continue;
885
- }
886
- let mut current = blocks[index].clone();
887
- for other in blocks.iter_mut().skip(index + 1) {
888
- try_merge_text_line(&mut current, other, font_size_tol);
889
- }
890
- current.adjust_bbox(false);
891
- merged.push(current);
892
- }
893
- merged
894
  }
895
 
896
- fn merge_paragraph_blocks(
897
- blocks: Vec<CtdBlock>,
898
- image_width: u32,
899
- image_height: u32,
900
- ) -> Vec<CtdBlock> {
901
- if blocks.len() < 2 {
902
- return blocks;
903
- }
904
-
905
- let mut blocks = sort_regions(blocks);
906
- let mut merged = Vec::new();
907
- for index in 0..blocks.len() {
908
- if blocks[index].merged {
909
- continue;
910
- }
911
- let mut current = blocks[index].clone();
912
- while let Some(candidate_index) = find_paragraph_merge_candidate(&current, &blocks, index) {
913
- let other = &mut blocks[candidate_index];
914
- merge_paragraph_block(&mut current, other, image_width, image_height);
915
  }
916
- current.adjust_bbox(false);
917
- merged.push(current);
918
- }
919
- merged
920
  }
921
 
922
- fn find_paragraph_merge_candidate(
923
- current: &CtdBlock,
924
- blocks: &[CtdBlock],
925
- current_index: usize,
926
- ) -> Option<usize> {
927
- let mut best: Option<(usize, f32, f32)> = None;
928
- for candidate_index in current_index + 1..blocks.len() {
929
- let candidate = &blocks[candidate_index];
930
- let Some(gap_y) = paragraph_merge_gap(current, candidate) else {
931
- continue;
932
- };
933
- if paragraph_merge_blocked(current, candidate, blocks, current_index, candidate_index) {
934
- continue;
935
- }
936
 
937
- let overlap_x = horizontal_overlap(&current.bbox, &candidate.bbox);
938
- match best {
939
- Some((_, best_gap, best_overlap)) => {
940
- if gap_y < best_gap - 1e-3
941
- || ((gap_y - best_gap).abs() <= 1e-3 && overlap_x > best_overlap)
942
- {
943
- best = Some((candidate_index, gap_y, overlap_x));
944
- }
945
  }
946
- None => best = Some((candidate_index, gap_y, overlap_x)),
947
- }
 
 
 
948
  }
949
 
950
- best.map(|(index, _, _)| index)
951
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
952
 
953
- fn paragraph_merge_gap(block: &CtdBlock, other: &CtdBlock) -> Option<f32> {
954
- if other.merged
955
- || block.source_direction != TextDirection::Horizontal
956
- || other.source_direction != TextDirection::Horizontal
957
- || block.lines.is_empty()
958
- || other.lines.is_empty()
959
- {
960
- return None;
961
- }
962
-
963
- let count_a = block.lines.len() as f32;
964
- let count_b = other.lines.len() as f32;
965
- let font_avg = (block.detected_font_size_px * count_a + other.detected_font_size_px * count_b)
966
- / (count_a + count_b).max(1.0);
967
- if font_avg <= 0.0 {
968
- return None;
969
- }
970
-
971
- let font_ratio = block.detected_font_size_px / other.detected_font_size_px.max(1e-6);
972
- if font_ratio > 2.0 || font_ratio.recip() > 2.0 {
973
- return None;
974
- }
975
-
976
- let (upper, lower) = if block.center()[1] <= other.center()[1] {
977
- (block.bbox, other.bbox)
978
- } else {
979
- (other.bbox, block.bbox)
980
- };
981
- let gap_y = lower[1] - upper[3];
982
- if gap_y < -font_avg * 0.25 || gap_y > font_avg * 0.9 {
983
- return None;
984
- }
985
-
986
- let left_diff = (block.bbox[0] - other.bbox[0]).abs();
987
- let right_diff = (block.bbox[2] - other.bbox[2]).abs();
988
- let overlap_x = horizontal_overlap(&block.bbox, &other.bbox);
989
- let width_similarity = overlap_x / block.bbox_width().min(other.bbox_width()).max(1.0);
990
- let aligned_left = left_diff <= font_avg * 1.1;
991
- let aligned_right = right_diff <= font_avg * 1.1;
992
- let ragged_continue = left_diff <= font_avg * 1.8 && width_similarity >= 0.75;
993
- if !(aligned_left || aligned_right || ragged_continue) {
994
- return None;
995
- }
996
-
997
- let vec_prod = dot(block.direction_vec, other.direction_vec);
998
- let cos_vec = vec_prod / (block.direction_norm * other.direction_norm).max(1e-6);
999
- if cos_vec.abs() < 0.95 {
1000
- return None;
1001
- }
1002
-
1003
- let angle_diff = (block.angle_deg - other.angle_deg).abs();
1004
- if angle_diff > 8.0 {
1005
- return None;
1006
- }
1007
-
1008
- Some(gap_y.max(0.0))
1009
- }
1010
-
1011
- fn paragraph_merge_blocked(
1012
- block: &CtdBlock,
1013
- other: &CtdBlock,
1014
- blocks: &[CtdBlock],
1015
- current_index: usize,
1016
- candidate_index: usize,
1017
- ) -> bool {
1018
- let column_left = block.bbox[0].max(other.bbox[0]);
1019
- let column_right = block.bbox[2].min(other.bbox[2]);
1020
- if column_right <= column_left {
1021
- return false;
1022
- }
1023
-
1024
- let (upper, lower) = if block.center()[1] <= other.center()[1] {
1025
- (block.bbox, other.bbox)
1026
- } else {
1027
- (other.bbox, block.bbox)
1028
- };
1029
- let gap_top = upper[3];
1030
- let gap_bottom = lower[1];
1031
- if gap_bottom <= gap_top {
1032
- return false;
1033
- }
1034
-
1035
- for (index, candidate) in blocks.iter().enumerate() {
1036
- if index == current_index || index == candidate_index || candidate.merged {
1037
- continue;
1038
- }
1039
- if candidate.source_direction != TextDirection::Horizontal {
1040
- continue;
1041
- }
1042
- if candidate.bbox[3] <= gap_top || candidate.bbox[1] >= gap_bottom {
1043
- continue;
1044
- }
1045
- let candidate_overlap =
1046
- (candidate.bbox[2].min(column_right) - candidate.bbox[0].max(column_left)).max(0.0);
1047
- if candidate_overlap > 0.0 {
1048
- return true;
1049
- }
1050
- }
1051
-
1052
- false
1053
- }
1054
-
1055
- fn merge_paragraph_block(
1056
- block: &mut CtdBlock,
1057
- other: &mut CtdBlock,
1058
- image_width: u32,
1059
- image_height: u32,
1060
- ) -> bool {
1061
- let Some(_) = paragraph_merge_gap(block, other) else {
1062
- return false;
1063
- };
1064
- let top = block.bbox[1].min(other.bbox[1]);
1065
- let bottom = block.bbox[3].max(other.bbox[3]);
1066
- let count_a = block.lines.len() as f32;
1067
- let count_b = other.lines.len() as f32;
1068
- let font_avg = (block.detected_font_size_px * count_a + other.detected_font_size_px * count_b)
1069
- / (count_a + count_b).max(1.0);
1070
-
1071
- block.lines.extend(other.lines.iter().copied());
1072
- block.direction_vec = [
1073
- block.direction_vec[0] + other.direction_vec[0],
1074
- block.direction_vec[1] + other.direction_vec[1],
1075
- ];
1076
- block.direction_norm = vector_norm(block.direction_vec).max(1.0);
1077
- block.distances.extend(other.distances.iter().copied());
1078
- block.detected_font_size_px = font_avg.max(1.0);
1079
- block.confidence = block.confidence.max(other.confidence);
1080
- block.bbox = [
1081
- block.bbox[0].min(other.bbox[0]),
1082
- top,
1083
- block.bbox[2].max(other.bbox[2]),
1084
- bottom,
1085
- ];
1086
- other.merged = true;
1087
- examine_block(block, image_width, image_height, true);
1088
- true
1089
- }
1090
-
1091
- fn try_merge_text_line(block: &mut CtdBlock, other: &mut CtdBlock, font_size_tol: f32) -> bool {
1092
- if other.merged || block.lines.is_empty() || other.lines.is_empty() {
1093
- return false;
1094
- }
1095
- if block.detected_font_size_px <= 0.0 || other.detected_font_size_px <= 0.0 {
1096
- return false;
1097
- }
1098
-
1099
- let font_ratio = block.detected_font_size_px / other.detected_font_size_px;
1100
- let count_a = block.lines.len() as f32;
1101
- let count_b = other.lines.len() as f32;
1102
- let font_avg = (block.detected_font_size_px * count_a + other.detected_font_size_px * count_b)
1103
- / (count_a + count_b);
1104
- let vec_prod = dot(block.direction_vec, other.direction_vec);
1105
- let cos_vec = vec_prod / (block.direction_norm * other.direction_norm).max(1e-6);
1106
- let line_a = block.lines[block.lines.len() - 1];
1107
- let line_b = other.lines[0];
1108
- let bbox_a = quad_bbox(&line_a);
1109
- let bbox_b = quad_bbox(&line_b);
1110
- let distance_x = bbox_a[0].max(bbox_b[0]) - bbox_a[2].min(bbox_b[2]);
1111
- let distance_y = bbox_a[1].max(bbox_b[1]) - bbox_a[3].min(bbox_b[3]);
1112
- let w1 = (bbox_a[2] - bbox_a[0]).max(1.0);
1113
- let w2 = (bbox_b[2] - bbox_b[0]).max(1.0);
1114
- let h1 = (bbox_a[3] - bbox_a[1]).max(1.0);
1115
- let h2 = (bbox_b[3] - bbox_b[1]).max(1.0);
1116
-
1117
- if !quads_intersect(&line_a, &line_b) {
1118
- match block.source_direction {
1119
- TextDirection::Vertical => {
1120
- if distance_y > 0.0 {
1121
- return false;
1122
- }
1123
- if distance_x > font_avg * 0.8 {
1124
- return false;
1125
- }
1126
- if distance_y.abs() / h1.min(h2) < 0.4 {
1127
- return false;
1128
- }
1129
- }
1130
- TextDirection::Horizontal => {
1131
- if distance_x > 0.0 {
1132
- return false;
1133
- }
1134
- let width_similarity = (w1.min(w2) / w1.max(w2)).clamp(0.0, 1.0);
1135
- let mut font_threshold = if font_avg < 24.0 { 0.6 } else { 0.5 };
1136
- if width_similarity > 0.95 {
1137
- font_threshold -= 0.08;
1138
- } else if width_similarity < 0.88 {
1139
- font_threshold += 0.1;
1140
- }
1141
- if distance_y > font_avg * font_threshold {
1142
- return false;
1143
- }
1144
- if distance_x.abs() / w1.min(w2) < 0.3 {
1145
- return false;
1146
- }
1147
- }
1148
- }
1149
-
1150
- if font_ratio > font_size_tol || font_ratio.recip() > font_size_tol {
1151
- return false;
1152
- }
1153
- if cos_vec.abs() < 0.866 {
1154
- return false;
1155
- }
1156
- }
1157
-
1158
- block.lines.extend(other.lines.iter().copied());
1159
- block.direction_vec = [
1160
- block.direction_vec[0] + other.direction_vec[0],
1161
- block.direction_vec[1] + other.direction_vec[1],
1162
- ];
1163
- block.direction_norm = vector_norm(block.direction_vec).max(1.0);
1164
- block.angle_deg = block.direction_vec[1]
1165
- .atan2(block.direction_vec[0])
1166
- .to_degrees();
1167
- if block.source_direction == TextDirection::Vertical {
1168
- block.angle_deg -= 90.0;
1169
- }
1170
- block.distances.extend(other.distances.iter().copied());
1171
- block.detected_font_size_px = font_avg.max(1.0);
1172
- block.confidence = block.confidence.max(other.confidence);
1173
- other.merged = true;
1174
- true
1175
- }
1176
-
1177
- fn sort_regions(blocks: Vec<CtdBlock>) -> Vec<CtdBlock> {
1178
- if blocks.len() < 2 {
1179
- return blocks;
1180
- }
1181
-
1182
- let vertical_blocks = blocks
1183
- .iter()
1184
- .filter(|block| block.source_direction == TextDirection::Vertical)
1185
- .count();
1186
- let right_to_left = !blocks.is_empty() && vertical_blocks * 2 >= blocks.len();
1187
- let order = stable_reading_order_indices(&blocks, right_to_left);
1188
- let mut ordered = Vec::with_capacity(blocks.len());
1189
- let mut slots = blocks.into_iter().map(Some).collect::<Vec<_>>();
1190
- for index in order {
1191
- if let Some(block) = slots[index].take() {
1192
- ordered.push(block);
1193
- }
1194
- }
1195
- ordered
1196
- }
1197
-
1198
- fn compare_blocks_for_reading_order(
1199
- a: &CtdBlock,
1200
- b: &CtdBlock,
1201
- right_to_left: bool,
1202
- ) -> std::cmp::Ordering {
1203
- let overlap_y = vertical_overlap(&a.bbox, &b.bbox);
1204
- let row_tolerance = (a.detected_font_size_px.max(b.detected_font_size_px) * 0.6).max(1.0);
1205
- if overlap_y > 0.0 || (a.center()[1] - b.center()[1]).abs() <= row_tolerance {
1206
- if right_to_left {
1207
- b.center()[0]
1208
- .total_cmp(&a.center()[0])
1209
- .then_with(|| a.center()[1].total_cmp(&b.center()[1]))
1210
- } else {
1211
- a.center()[0]
1212
- .total_cmp(&b.center()[0])
1213
- .then_with(|| a.center()[1].total_cmp(&b.center()[1]))
1214
- }
1215
- } else {
1216
- a.center()[1]
1217
- .total_cmp(&b.center()[1])
1218
- .then_with(|| a.center()[0].total_cmp(&b.center()[0]))
1219
- }
1220
- }
1221
-
1222
- fn stable_reading_order_indices(blocks: &[CtdBlock], right_to_left: bool) -> Vec<usize> {
1223
- let mut edges = vec![Vec::new(); blocks.len()];
1224
- let mut indegree = vec![0usize; blocks.len()];
1225
- for left in 0..blocks.len() {
1226
- for right in (left + 1)..blocks.len() {
1227
- match compare_blocks_for_reading_order(&blocks[left], &blocks[right], right_to_left) {
1228
- std::cmp::Ordering::Less => {
1229
- edges[left].push(right);
1230
- indegree[right] += 1;
1231
- }
1232
- std::cmp::Ordering::Greater => {
1233
- edges[right].push(left);
1234
- indegree[left] += 1;
1235
- }
1236
- std::cmp::Ordering::Equal => {}
1237
- }
1238
- }
1239
- }
1240
-
1241
- let mut remaining = (0..blocks.len()).collect::<Vec<_>>();
1242
- let mut ordered = Vec::with_capacity(blocks.len());
1243
- while !remaining.is_empty() {
1244
- let available = remaining
1245
- .iter()
1246
- .copied()
1247
- .filter(|index| indegree[*index] == 0)
1248
- .collect::<Vec<_>>();
1249
- let next = available.first().copied().unwrap_or_else(|| remaining[0]);
1250
- ordered.push(next);
1251
- remaining.retain(|index| *index != next);
1252
- for successor in &edges[next] {
1253
- indegree[*successor] = indegree[*successor].saturating_sub(1);
1254
- }
1255
- }
1256
-
1257
- ordered
1258
- }
1259
- fn dedupe_blocks(blocks: &mut Vec<CtdBlock>) {
1260
- if blocks.len() < 2 {
1261
- return;
1262
- }
1263
-
1264
- let mut deduped = vec![blocks[0].clone()];
1265
- for block in blocks.iter().skip(1) {
1266
- let area = bbox_area(&block.bbox).max(1e-6);
1267
- let mut keep = true;
1268
- for existing in &deduped {
1269
- let intersection = overlap_area(&block.bbox, &existing.bbox);
1270
- if intersection / area > 0.9 {
1271
- keep = false;
1272
- break;
1273
- }
1274
- }
1275
- if keep {
1276
- deduped.push(block.clone());
1277
- }
1278
- }
1279
- *blocks = deduped;
1280
- }
1281
-
1282
- fn refine_mask(image: &RgbImage, pred_mask: &GrayImage, blocks: &[TextBlock]) -> GrayImage {
1283
- let mut refined = GrayImage::new(pred_mask.width(), pred_mask.height());
1284
- for block in blocks {
1285
- let bbox = [
1286
- block.x,
1287
- block.y,
1288
- block.x + block.width,
1289
- block.y + block.height,
1290
- ];
1291
- let [x1f, y1f, x2f, y2f] =
1292
- enlarge_window(bbox, image.width() as f32, image.height() as f32);
1293
- if x2f <= x1f || y2f <= y1f {
1294
- continue;
1295
- }
1296
-
1297
- let x1 = x1f as u32;
1298
- let y1 = y1f as u32;
1299
- let width = x2f as u32 - x1;
1300
- let height = y2f as u32 - y1;
1301
- let rgb_crop = imageops::crop_imm(image, x1, y1, width, height).to_image();
1302
- let mask_crop = imageops::crop_imm(pred_mask, x1, y1, width, height).to_image();
1303
- let mut candidates = topk_mask_candidates(&rgb_crop, &mask_crop);
1304
- candidates.extend(otsu_mask_candidates(&rgb_crop, &mask_crop));
1305
- let merged = merge_mask_candidates(candidates, &mask_crop);
1306
-
1307
- for local_y in 0..height {
1308
- for local_x in 0..width {
1309
- if merged.get_pixel(local_x, local_y)[0] > 0 {
1310
- refined.put_pixel(x1 + local_x, y1 + local_y, Luma([255]));
1311
- }
1312
- }
1313
- }
1314
- }
1315
- refined
1316
- }
1317
-
1318
- fn topk_mask_candidates(image: &RgbImage, pred_mask: &GrayImage) -> Vec<CandidateMask> {
1319
- let eroded = erode(pred_mask, Norm::LInf, 1);
1320
- let gray = DynamicImage::ImageRgb8(image.clone())
1321
- .grayscale()
1322
- .to_luma8();
1323
- let mut histogram = [0u32; 256];
1324
- let mut total = 0u32;
1325
- for (pixel, mask_pixel) in gray.pixels().zip(eroded.pixels()) {
1326
- if mask_pixel[0] > 127 {
1327
- histogram[pixel[0] as usize] += 1;
1328
- total += 1;
1329
- }
1330
- }
1331
- if total == 0 {
1332
- return Vec::new();
1333
- }
1334
-
1335
- let mut colors: Vec<(u8, u32)> = histogram
1336
- .iter()
1337
- .enumerate()
1338
- .filter_map(|(index, count)| {
1339
- if *count > 0 {
1340
- Some((index as u8, *count))
1341
- } else {
1342
- None
1343
- }
1344
- })
1345
- .collect();
1346
- colors.sort_by(|a, b| b.1.cmp(&a.1));
1347
-
1348
- let mut top_colors = Vec::new();
1349
- let bin_tol = (total as f32 * 0.001).ceil() as u32;
1350
- for (color, count) in colors {
1351
- if top_colors
1352
- .iter()
1353
- .all(|existing: &u8| existing.abs_diff(color) > 10)
1354
- {
1355
- top_colors.push(color);
1356
- }
1357
- if top_colors.len() >= 3 || count < bin_tol {
1358
- break;
1359
- }
1360
- }
1361
-
1362
- top_colors
1363
- .into_iter()
1364
- .map(|color| {
1365
- let top = color.saturating_add(30);
1366
- let bottom = color.saturating_sub(30);
1367
- let thresholded = GrayImage::from_fn(gray.width(), gray.height(), |x, y| {
1368
- let value = gray.get_pixel(x, y)[0];
1369
- if value >= bottom && value <= top {
1370
- Luma([255u8])
1371
- } else {
1372
- Luma([0u8])
1373
- }
1374
- });
1375
- minxor_threshold(thresholded, pred_mask)
1376
- })
1377
- .collect()
1378
- }
1379
-
1380
- fn otsu_mask_candidates(image: &RgbImage, pred_mask: &GrayImage) -> Vec<CandidateMask> {
1381
- let mut candidates = Vec::new();
1382
- for channel in 0..3 {
1383
- let channel_image = GrayImage::from_fn(image.width(), image.height(), |x, y| {
1384
- Luma([image.get_pixel(x, y)[channel]])
1385
- });
1386
- let level = otsu_level(&channel_image);
1387
- let thresholded =
1388
- GrayImage::from_fn(channel_image.width(), channel_image.height(), |x, y| {
1389
- if channel_image.get_pixel(x, y)[0] > level {
1390
- Luma([255u8])
1391
- } else {
1392
- Luma([0u8])
1393
- }
1394
- });
1395
- candidates.push(minxor_threshold(thresholded, pred_mask));
1396
- }
1397
- candidates.sort_by(|a, b| a.xor_sum.cmp(&b.xor_sum));
1398
- candidates.into_iter().take(1).collect()
1399
- }
1400
-
1401
- fn minxor_threshold(thresholded: GrayImage, pred_mask: &GrayImage) -> CandidateMask {
1402
- let inverted = invert_binary(&thresholded);
1403
- let regular_xor = xor_sum(&thresholded, pred_mask);
1404
- let inverted_xor = xor_sum(&inverted, pred_mask);
1405
- if inverted_xor < regular_xor {
1406
- CandidateMask {
1407
- mask: inverted,
1408
- xor_sum: inverted_xor,
1409
- }
1410
- } else {
1411
- CandidateMask {
1412
- mask: thresholded,
1413
- xor_sum: regular_xor,
1414
- }
1415
- }
1416
- }
1417
-
1418
- fn merge_mask_candidates(mut candidates: Vec<CandidateMask>, pred_mask: &GrayImage) -> GrayImage {
1419
- candidates.sort_by(|a, b| a.xor_sum.cmp(&b.xor_sum));
1420
- let mut mask_merged = GrayImage::new(pred_mask.width(), pred_mask.height());
1421
- let pred = threshold_binary(&erode(pred_mask, Norm::LInf, 1), 60);
1422
-
1423
- for candidate in candidates {
1424
- let components = connected_components_stats(&candidate.mask, Connectivity::Eight);
1425
- for component in components {
1426
- if component.w * component.h < 3 {
1427
- continue;
1428
- }
1429
-
1430
- let current = imageops::crop_imm(
1431
- &mask_merged,
1432
- component.x,
1433
- component.y,
1434
- component.w,
1435
- component.h,
1436
- )
1437
- .to_image();
1438
- let mut combined = current.clone();
1439
- for (x, y) in &component.pixels {
1440
- combined.put_pixel(*x - component.x, *y - component.y, Luma([255]));
1441
- }
1442
-
1443
- let pred_crop =
1444
- imageops::crop_imm(&pred, component.x, component.y, component.w, component.h)
1445
- .to_image();
1446
- if xor_sum(&combined, &pred_crop) < xor_sum(&current, &pred_crop) {
1447
- for local_y in 0..component.h {
1448
- for local_x in 0..component.w {
1449
- let pixel = combined.get_pixel(local_x, local_y);
1450
- if pixel[0] > 0 {
1451
- mask_merged.put_pixel(
1452
- component.x + local_x,
1453
- component.y + local_y,
1454
- *pixel,
1455
- );
1456
- }
1457
- }
1458
- }
1459
- }
1460
- }
1461
- }
1462
-
1463
- let mut mask_merged = dilate(&mask_merged, Norm::LInf, 2);
1464
- let inverted = invert_binary(&mask_merged);
1465
- let holes = connected_components_stats(&inverted, Connectivity::Eight);
1466
- if !holes.is_empty() {
1467
- let mut areas: Vec<u32> = holes.iter().map(|component| component.area).collect();
1468
- areas.sort_unstable();
1469
- let area_threshold = if areas.len() > 1 {
1470
- areas[areas.len() - 2]
1471
- } else {
1472
- areas[0]
1473
- };
1474
-
1475
- for component in holes {
1476
- if component.area >= area_threshold {
1477
- continue;
1478
- }
1479
-
1480
- let current = imageops::crop_imm(
1481
- &mask_merged,
1482
- component.x,
1483
- component.y,
1484
- component.w,
1485
- component.h,
1486
- )
1487
- .to_image();
1488
- let mut combined = current.clone();
1489
- for (x, y) in &component.pixels {
1490
- combined.put_pixel(*x - component.x, *y - component.y, Luma([255]));
1491
- }
1492
- let pred_crop =
1493
- imageops::crop_imm(&pred, component.x, component.y, component.w, component.h)
1494
- .to_image();
1495
- if xor_sum(&combined, &pred_crop) < xor_sum(&current, &pred_crop) {
1496
- for local_y in 0..component.h {
1497
- for local_x in 0..component.w {
1498
- let pixel = combined.get_pixel(local_x, local_y);
1499
- if pixel[0] > 0 {
1500
- mask_merged.put_pixel(
1501
- component.x + local_x,
1502
- component.y + local_y,
1503
- *pixel,
1504
- );
1505
- }
1506
- }
1507
- }
1508
- }
1509
- }
1510
- }
1511
-
1512
- mask_merged
1513
- }
1514
-
1515
- fn connected_components_stats(image: &GrayImage, connectivity: Connectivity) -> Vec<Component> {
1516
- let labels = connected_components(image, connectivity, Luma([0u8]));
1517
- let max_label = labels.pixels().map(|pixel| pixel[0]).max().unwrap_or(0);
1518
- let mut components = vec![
1519
- Component {
1520
- x: 0,
1521
- y: 0,
1522
- w: 0,
1523
- h: 0,
1524
- area: 0,
1525
- pixels: Vec::new(),
1526
- };
1527
- (max_label + 1) as usize
1528
- ];
1529
-
1530
- for component in components.iter_mut().skip(1) {
1531
- component.x = u32::MAX;
1532
- component.y = u32::MAX;
1533
- }
1534
-
1535
- for y in 0..labels.height() {
1536
- for x in 0..labels.width() {
1537
- let label = labels.get_pixel(x, y)[0];
1538
- if label == 0 {
1539
- continue;
1540
- }
1541
- let component = &mut components[label as usize];
1542
- component.area += 1;
1543
- component.x = component.x.min(x);
1544
- component.y = component.y.min(y);
1545
- component.w = component.w.max(x);
1546
- component.h = component.h.max(y);
1547
- component.pixels.push((x, y));
1548
- }
1549
- }
1550
-
1551
- for component in components.iter_mut().skip(1) {
1552
- if component.pixels.is_empty() {
1553
- continue;
1554
- }
1555
- component.w = component.w.saturating_sub(component.x) + 1;
1556
- component.h = component.h.saturating_sub(component.y) + 1;
1557
- }
1558
-
1559
- components
1560
- .into_iter()
1561
- .skip(1)
1562
- .filter(|component| component.area > 0)
1563
- .collect()
1564
- }
1565
-
1566
- fn quad_midpoints(quad: &Quad) -> Quad {
1567
- [
1568
- midpoint(quad[0], quad[1]),
1569
- midpoint(quad[1], quad[2]),
1570
- midpoint(quad[2], quad[3]),
1571
- midpoint(quad[3], quad[0]),
1572
- ]
1573
- }
1574
-
1575
- fn quad_axis_lengths(quad: &Quad) -> (f32, f32) {
1576
- let midpoints = quad_midpoints(quad);
1577
- let vec_v = [
1578
- midpoints[2][0] - midpoints[0][0],
1579
- midpoints[2][1] - midpoints[0][1],
1580
- ];
1581
- let vec_h = [
1582
- midpoints[1][0] - midpoints[3][0],
1583
- midpoints[1][1] - midpoints[3][1],
1584
- ];
1585
- (vector_norm(vec_v), vector_norm(vec_h))
1586
- }
1587
-
1588
- fn midpoint(a: [f32; 2], b: [f32; 2]) -> [f32; 2] {
1589
- [(a[0] + b[0]) * 0.5, (a[1] + b[1]) * 0.5]
1590
- }
1591
-
1592
- fn quad_to_tuples(quad: &Quad) -> [(f32, f32); 4] {
1593
- [
1594
- (quad[0][0], quad[0][1]),
1595
- (quad[1][0], quad[1][1]),
1596
- (quad[2][0], quad[2][1]),
1597
- (quad[3][0], quad[3][1]),
1598
- ]
1599
- }
1600
-
1601
- fn reorder_quad_horizontal(quad: &Quad) -> Quad {
1602
- let mut points = *quad;
1603
- points.sort_by(|a, b| a[0].total_cmp(&b[0]).then_with(|| a[1].total_cmp(&b[1])));
1604
- let mut left = [points[0], points[1]];
1605
- let mut right = [points[2], points[3]];
1606
- left.sort_by(|a, b| a[1].total_cmp(&b[1]));
1607
- right.sort_by(|a, b| a[1].total_cmp(&b[1]));
1608
- [left[0], right[0], right[1], left[1]]
1609
- }
1610
-
1611
- fn reorder_quad_vertical(quad: &Quad) -> Quad {
1612
- let mut points = *quad;
1613
- points.sort_by(|a, b| a[1].total_cmp(&b[1]).then_with(|| a[0].total_cmp(&b[0])));
1614
- let mut top = [points[0], points[1]];
1615
- let mut bottom = [points[2], points[3]];
1616
- top.sort_by(|a, b| a[0].total_cmp(&b[0]));
1617
- bottom.sort_by(|a, b| b[0].total_cmp(&a[0]));
1618
- [top[0], top[1], bottom[0], bottom[1]]
1619
- }
1620
-
1621
- fn sort_quad_points(quad: &Quad) -> (Quad, bool) {
1622
- let mut pairwise = Vec::with_capacity(16);
1623
- for a in quad {
1624
- for b in quad {
1625
- let vec = [a[0] - b[0], a[1] - b[1]];
1626
- pairwise.push((vec, vector_norm(vec)));
1627
- }
1628
- }
1629
-
1630
- let mut sorted_ids: Vec<usize> = (0..pairwise.len()).collect();
1631
- sorted_ids.sort_by(|a, b| pairwise[*a].1.total_cmp(&pairwise[*b].1));
1632
- let mut long_side_vecs = [pairwise[sorted_ids[8]].0, pairwise[sorted_ids[10]].0];
1633
- if dot(long_side_vecs[0], long_side_vecs[1]) < 0.0 {
1634
- long_side_vecs[0] = [-long_side_vecs[0][0], -long_side_vecs[0][1]];
1635
- }
1636
-
1637
- let structure_vec = [
1638
- (long_side_vecs[0][0] + long_side_vecs[1][0]).abs() * 0.5,
1639
- (long_side_vecs[0][1] + long_side_vecs[1][1]).abs() * 0.5,
1640
- ];
1641
- let sorted_norms: Vec<f32> = sorted_ids.iter().map(|id| pairwise[*id].1).collect();
1642
- let square = sorted_norms[4..12]
1643
- .iter()
1644
- .copied()
1645
- .fold((f32::MAX, f32::MIN), |(min_v, max_v), value| {
1646
- (min_v.min(value), max_v.max(value))
1647
- });
1648
- let mut vertical = structure_vec[0] * 1.2 <= structure_vec[1];
1649
- if square.1 - square.0 < 1e-3 {
1650
- vertical = false;
1651
- }
1652
-
1653
- if vertical {
1654
- (reorder_quad_vertical(quad), true)
1655
- } else {
1656
- (reorder_quad_horizontal(quad), false)
1657
- }
1658
- }
1659
-
1660
- fn clip_quad(quad: &Quad, width: f32, height: f32) -> Quad {
1661
- let mut clipped = *quad;
1662
- for point in &mut clipped {
1663
- point[0] = point[0].clamp(0.0, width);
1664
- point[1] = point[1].clamp(0.0, height);
1665
- }
1666
- clipped
1667
- }
1668
-
1669
- fn scale_quad(quad: &Quad, scale_x: f32, scale_y: f32) -> Quad {
1670
- let mut scaled = *quad;
1671
- for point in &mut scaled {
1672
- point[0] *= scale_x;
1673
- point[1] *= scale_y;
1674
- }
1675
- scaled
1676
- }
1677
-
1678
- fn quad_bbox(quad: &Quad) -> [f32; 4] {
1679
- let mut min_x = f32::MAX;
1680
- let mut min_y = f32::MAX;
1681
- let mut max_x = f32::MIN;
1682
- let mut max_y = f32::MIN;
1683
- for point in quad {
1684
- min_x = min_x.min(point[0]);
1685
- min_y = min_y.min(point[1]);
1686
- max_x = max_x.max(point[0]);
1687
- max_y = max_y.max(point[1]);
1688
- }
1689
- [min_x, min_y, max_x, max_y]
1690
- }
1691
-
1692
- fn bbox_area(bbox: &[f32; 4]) -> f32 {
1693
- (bbox[2] - bbox[0]).max(0.0) * (bbox[3] - bbox[1]).max(0.0)
1694
- }
1695
-
1696
- fn overlap_area(a: &[f32; 4], b: &[f32; 4]) -> f32 {
1697
- let x1 = a[0].max(b[0]);
1698
- let y1 = a[1].max(b[1]);
1699
- let x2 = a[2].min(b[2]);
1700
- let y2 = a[3].min(b[3]);
1701
- if x2 <= x1 || y2 <= y1 {
1702
- return 0.0;
1703
- }
1704
- (x2 - x1) * (y2 - y1)
1705
- }
1706
-
1707
- fn horizontal_overlap(a: &[f32; 4], b: &[f32; 4]) -> f32 {
1708
- (a[2].min(b[2]) - a[0].max(b[0])).max(0.0)
1709
- }
1710
-
1711
- fn vertical_overlap(a: &[f32; 4], b: &[f32; 4]) -> f32 {
1712
- (a[3].min(b[3]) - a[1].max(b[1])).max(0.0)
1713
- }
1714
-
1715
- fn mean_mask_score(mask: &GrayImage, bbox: &[f32; 4]) -> f32 {
1716
- let x1 = bbox[0].floor().max(0.0) as u32;
1717
- let y1 = bbox[1].floor().max(0.0) as u32;
1718
- let x2 = bbox[2].ceil().min(mask.width() as f32) as u32;
1719
- let y2 = bbox[3].ceil().min(mask.height() as f32) as u32;
1720
- if x2 <= x1 || y2 <= y1 {
1721
- return 0.0;
1722
- }
1723
-
1724
- let mut sum = 0u64;
1725
- let mut count = 0u64;
1726
- for y in y1..y2 {
1727
- for x in x1..x2 {
1728
- sum += mask.get_pixel(x, y)[0] as u64;
1729
- count += 1;
1730
- }
1731
- }
1732
- if count == 0 {
1733
- 0.0
1734
- } else {
1735
- (sum as f32 / count as f32) / 255.0
1736
- }
1737
- }
1738
-
1739
- fn enlarge_window(bbox: [f32; 4], image_width: f32, image_height: f32) -> [f32; 4] {
1740
- let w = bbox[2] - bbox[0];
1741
- let h = bbox[3] - bbox[1];
1742
- if w <= 0.0 || h <= 0.0 {
1743
- return [0.0, 0.0, 0.0, 0.0];
1744
- }
1745
-
1746
- let a = 1.0f32;
1747
- let b = w + h;
1748
- let c = (1.0 - 2.5) * w * h;
1749
- let delta = (b * b - 4.0 * a * c).max(0.0).sqrt();
1750
- let grow = ((-b + delta) / (2.0 * a)).max(0.0) * 0.5;
1751
- let grow_x = grow.min(bbox[0]).min(image_width - bbox[2]);
1752
- let grow_y = grow.min(bbox[1]).min(image_height - bbox[3]);
1753
-
1754
- [
1755
- (bbox[0] - grow_x).clamp(0.0, image_width),
1756
- (bbox[1] - grow_y).clamp(0.0, image_height),
1757
- (bbox[2] + grow_x).clamp(0.0, image_width),
1758
- (bbox[3] + grow_y).clamp(0.0, image_height),
1759
- ]
1760
- }
1761
-
1762
- fn invert_binary(image: &GrayImage) -> GrayImage {
1763
- GrayImage::from_fn(image.width(), image.height(), |x, y| {
1764
- if image.get_pixel(x, y)[0] > 0 {
1765
- Luma([0u8])
1766
- } else {
1767
- Luma([255u8])
1768
- }
1769
- })
1770
- }
1771
-
1772
- fn threshold_binary(image: &GrayImage, threshold: u8) -> GrayImage {
1773
- GrayImage::from_fn(image.width(), image.height(), |x, y| {
1774
- if image.get_pixel(x, y)[0] > threshold {
1775
- Luma([255u8])
1776
- } else {
1777
- Luma([0u8])
1778
- }
1779
- })
1780
- }
1781
-
1782
- fn xor_sum(a: &GrayImage, b: &GrayImage) -> u64 {
1783
- a.pixels()
1784
- .zip(b.pixels())
1785
- .map(|(left, right)| (left[0] ^ right[0]) as u64)
1786
- .sum()
1787
- }
1788
-
1789
- fn quads_intersect(a: &Quad, b: &Quad) -> bool {
1790
- let mut axes = Vec::with_capacity(8);
1791
- axes.extend(quad_axes(a));
1792
- axes.extend(quad_axes(b));
1793
-
1794
- for axis in axes {
1795
- let (a_min, a_max) = project_quad(a, axis);
1796
- let (b_min, b_max) = project_quad(b, axis);
1797
- if a_max < b_min || b_max < a_min {
1798
- return false;
1799
- }
1800
- }
1801
- true
1802
- }
1803
-
1804
- fn quad_axes(quad: &Quad) -> Vec<[f32; 2]> {
1805
- let mut axes = Vec::with_capacity(4);
1806
- for index in 0..4 {
1807
- let next = (index + 1) % 4;
1808
- let edge = [
1809
- quad[next][0] - quad[index][0],
1810
- quad[next][1] - quad[index][1],
1811
- ];
1812
- let normal = [-edge[1], edge[0]];
1813
- let norm = vector_norm(normal);
1814
- if norm > 0.0 {
1815
- axes.push([normal[0] / norm, normal[1] / norm]);
1816
- }
1817
- }
1818
- axes
1819
- }
1820
-
1821
- fn project_quad(quad: &Quad, axis: [f32; 2]) -> (f32, f32) {
1822
- let mut min = f32::MAX;
1823
- let mut max = f32::MIN;
1824
- for point in quad {
1825
- let projection = point[0] * axis[0] + point[1] * axis[1];
1826
- min = min.min(projection);
1827
- max = max.max(projection);
1828
- }
1829
- (min, max)
1830
- }
1831
-
1832
- fn perpendicular_distance(vector: [f32; 2], axis: [f32; 2], _unused: f32) -> f32 {
1833
- let axis_norm = vector_norm(axis).max(1e-6);
1834
- let dot = vector[0] * axis[0] + vector[1] * axis[1];
1835
- let vector_norm = vector_norm(vector).max(1e-6);
1836
- let cos = (dot / (vector_norm * axis_norm)).clamp(-1.0, 1.0);
1837
- (1.0 - cos * cos).max(0.0).sqrt() * vector_norm
1838
- }
1839
-
1840
- fn vector_norm(vector: [f32; 2]) -> f32 {
1841
- (vector[0] * vector[0] + vector[1] * vector[1]).sqrt()
1842
- }
1843
-
1844
- fn cross_2d(a: [f32; 2], b: [f32; 2]) -> f32 {
1845
- a[0] * b[1] - a[1] * b[0]
1846
- }
1847
-
1848
- fn dot(a: [f32; 2], b: [f32; 2]) -> f32 {
1849
- a[0] * b[0] + a[1] * b[1]
1850
- }
1851
-
1852
- trait BboxExt {
1853
- fn bbox_width(&self) -> f32;
1854
- fn bbox_height(&self) -> f32;
1855
- }
1856
-
1857
- impl BboxExt for CtdBlock {
1858
- fn bbox_width(&self) -> f32 {
1859
- self.bbox[2] - self.bbox[0]
1860
- }
1861
-
1862
- fn bbox_height(&self) -> f32 {
1863
- self.bbox[3] - self.bbox[1]
1864
- }
1865
- }
1866
-
1867
- #[cfg(test)]
1868
- mod tests {
1869
- use super::*;
1870
-
1871
- fn test_block(bbox: [f32; 4], source_direction: TextDirection) -> CtdBlock {
1872
- CtdBlock {
1873
- bbox,
1874
- confidence: 0.9,
1875
- source_language: "unknown".to_string(),
1876
- source_direction,
1877
- lines: Vec::new(),
1878
- angle_deg: 0.0,
1879
- detected_font_size_px: 10.0,
1880
- distances: Vec::new(),
1881
- direction_vec: [1.0, 0.0],
1882
- direction_norm: 1.0,
1883
- merged: false,
1884
- }
1885
- }
1886
-
1887
- #[test]
1888
- fn component_quad_tracks_orientation() {
1889
- let component = Component {
1890
- x: 0,
1891
- y: 0,
1892
- w: 30,
1893
- h: 8,
1894
- area: 40,
1895
- pixels: (0..20).map(|x| (x, 2 + (x / 5))).collect(),
1896
- };
1897
-
1898
- let (quad, vertical) = component_quad(&component).expect("quad");
1899
- assert!(!vertical);
1900
- let bbox = quad_bbox(&quad);
1901
- assert!(bbox[2] > bbox[0]);
1902
- assert!(bbox[3] > bbox[1]);
1903
- }
1904
-
1905
- #[test]
1906
- fn merge_text_lines_keeps_adjacent_horizontal_lines() {
1907
- let line_a = DetectedLine {
1908
- quad: [[0.0, 0.0], [20.0, 0.0], [20.0, 8.0], [0.0, 8.0]],
1909
- vertical: false,
1910
- score: 0.9,
1911
- };
1912
- let line_b = DetectedLine {
1913
- quad: [[18.0, 0.0], [38.0, 0.0], [38.0, 8.0], [18.0, 8.0]],
1914
- vertical: false,
1915
- score: 0.9,
1916
- };
1917
- let mut block_a = CtdBlock::from_line(&line_a);
1918
- let mut block_b = CtdBlock::from_line(&line_b);
1919
- examine_block(&mut block_a, 100, 100, true);
1920
- examine_block(&mut block_b, 100, 100, true);
1921
-
1922
- assert!(try_merge_text_line(&mut block_a, &mut block_b, 2.0));
1923
- assert_eq!(block_a.lines.len(), 2);
1924
  }
1925
 
1926
  #[test]
1927
- fn transformed_regions_fall_back_to_bbox_without_ctd_metadata() {
1928
- let image = DynamicImage::ImageRgb8(RgbImage::from_pixel(32, 32, Rgb([255, 255, 255])));
1929
  let block = TextBlock {
1930
  x: 4.0,
1931
- y: 6.0,
1932
  width: 10.0,
1933
- height: 12.0,
1934
  ..Default::default()
1935
  };
1936
 
1937
  let regions = extract_text_block_regions(&image, &block);
1938
  assert_eq!(regions.len(), 1);
1939
  assert_eq!(regions[0].width(), 10);
1940
- assert_eq!(regions[0].height(), 12);
1941
  }
1942
 
1943
  #[test]
@@ -1965,233 +379,4 @@ mod tests {
1965
  assert!(crop.width() > 12);
1966
  assert!(crop.height() > 8);
1967
  }
1968
-
1969
- #[test]
1970
- fn group_output_builds_line_only_ctd_blocks() {
1971
- let line = DetectedLine {
1972
- quad: [[10.0, 10.0], [30.0, 10.0], [30.0, 18.0], [10.0, 18.0]],
1973
- vertical: false,
1974
- score: 0.95,
1975
- };
1976
- let mask = GrayImage::from_fn(48, 48, |x, y| {
1977
- if (10..30).contains(&x) && (10..18).contains(&y) {
1978
- Luma([255])
1979
- } else {
1980
- Luma([0])
1981
- }
1982
- });
1983
-
1984
- let blocks = group_output(&[line], &mask, 48, 48);
1985
- assert_eq!(blocks.len(), 1);
1986
- assert_eq!(blocks[0].detector.as_deref(), Some("ctd"));
1987
- assert_eq!(blocks[0].source_direction, Some(TextDirection::Horizontal));
1988
- assert_eq!(blocks[0].line_polygons.as_ref().map(Vec::len), Some(1));
1989
- assert!(blocks[0].detected_font_size_px.unwrap_or_default() > 0.0);
1990
- }
1991
-
1992
- #[test]
1993
- fn transformed_regions_rotate_vertical_ctd_lines() {
1994
- let mut image = RgbImage::from_pixel(48, 48, Rgb([255, 255, 255]));
1995
- for y in 8..40 {
1996
- for x in 20..28 {
1997
- image.put_pixel(x, y, Rgb([0, 0, 0]));
1998
- }
1999
- }
2000
- let block = TextBlock {
2001
- x: 18.0,
2002
- y: 8.0,
2003
- width: 12.0,
2004
- height: 32.0,
2005
- line_polygons: Some(vec![[[20.0, 8.0], [28.0, 8.0], [28.0, 40.0], [20.0, 40.0]]]),
2006
- source_direction: Some(TextDirection::Vertical),
2007
- rotation_deg: Some(0.0),
2008
- detected_font_size_px: Some(8.0),
2009
- detector: Some("ctd".to_string()),
2010
- ..Default::default()
2011
- };
2012
-
2013
- let regions = extract_text_block_regions(&DynamicImage::ImageRgb8(image), &block);
2014
- assert_eq!(regions.len(), 1);
2015
- assert!(regions[0].width() > regions[0].height());
2016
- }
2017
-
2018
- #[test]
2019
- fn refine_mask_returns_pixels_for_ctd_blocks() {
2020
- let mut image = RgbImage::from_pixel(32, 32, Rgb([255, 255, 255]));
2021
- let pred_mask = GrayImage::from_fn(32, 32, |x, y| {
2022
- if (10..22).contains(&x) && (12..18).contains(&y) {
2023
- Luma([255])
2024
- } else {
2025
- Luma([0])
2026
- }
2027
- });
2028
- for y in 12..18 {
2029
- for x in 10..22 {
2030
- image.put_pixel(x, y, Rgb([0, 0, 0]));
2031
- }
2032
- }
2033
- let block = TextBlock {
2034
- x: 10.0,
2035
- y: 12.0,
2036
- width: 12.0,
2037
- height: 6.0,
2038
- line_polygons: Some(vec![[
2039
- [10.0, 12.0],
2040
- [22.0, 12.0],
2041
- [22.0, 18.0],
2042
- [10.0, 18.0],
2043
- ]]),
2044
- source_direction: Some(TextDirection::Horizontal),
2045
- rotation_deg: Some(0.0),
2046
- detected_font_size_px: Some(6.0),
2047
- detector: Some("ctd".to_string()),
2048
- ..Default::default()
2049
- };
2050
-
2051
- let refined = refine_mask(&image, &pred_mask, &[block]);
2052
- assert_eq!(refined.get_pixel(0, 0)[0], 0);
2053
- assert!(refined.get_pixel(16, 15)[0] > 0);
2054
- }
2055
-
2056
- #[test]
2057
- fn refine_segmentation_mask_thresholds_when_blocks_are_missing() {
2058
- let image = DynamicImage::ImageRgb8(RgbImage::from_pixel(16, 16, Rgb([255, 255, 255])));
2059
- let pred_mask = GrayImage::from_fn(16, 16, |x, y| {
2060
- if (4..12).contains(&x) && (5..11).contains(&y) {
2061
- Luma([200])
2062
- } else {
2063
- Luma([0])
2064
- }
2065
- });
2066
-
2067
- let mask = refine_segmentation_mask(&image, &pred_mask, &[]);
2068
- assert_eq!(mask.get_pixel(0, 0)[0], 0);
2069
- assert!(mask.get_pixel(8, 8)[0] > 0);
2070
- }
2071
-
2072
- #[test]
2073
- fn paragraph_merge_joins_stacked_horizontal_blocks() {
2074
- let make_block = |lines: Vec<Quad>| {
2075
- let mut block = CtdBlock {
2076
- bbox: [0.0, 0.0, 0.0, 0.0],
2077
- confidence: 0.9,
2078
- source_language: "unknown".to_string(),
2079
- source_direction: TextDirection::Horizontal,
2080
- lines,
2081
- angle_deg: 0.0,
2082
- detected_font_size_px: 0.0,
2083
- distances: Vec::new(),
2084
- direction_vec: [1.0, 0.0],
2085
- direction_norm: 1.0,
2086
- merged: false,
2087
- };
2088
- block.adjust_bbox(false);
2089
- examine_block(&mut block, 2000, 2000, true);
2090
- block
2091
- };
2092
-
2093
- let blocks = vec![
2094
- make_block(vec![[
2095
- [10.0, 10.0],
2096
- [110.0, 10.0],
2097
- [110.0, 30.0],
2098
- [10.0, 30.0],
2099
- ]]),
2100
- make_block(vec![[
2101
- [12.0, 42.0],
2102
- [108.0, 42.0],
2103
- [108.0, 60.0],
2104
- [12.0, 60.0],
2105
- ]]),
2106
- make_block(vec![
2107
- [[9.0, 74.0], [112.0, 74.0], [112.0, 94.0], [9.0, 94.0]],
2108
- [[8.0, 106.0], [105.0, 106.0], [105.0, 126.0], [8.0, 126.0]],
2109
- ]),
2110
- ];
2111
-
2112
- let merged = merge_paragraph_blocks(blocks, 2000, 2000);
2113
- assert_eq!(merged.len(), 1);
2114
- assert_eq!(merged[0].lines.len(), 4);
2115
- }
2116
-
2117
- #[test]
2118
- fn paragraph_merge_does_not_skip_over_intervening_block() {
2119
- let make_block = |lines: Vec<Quad>| {
2120
- let mut block = CtdBlock {
2121
- bbox: [0.0, 0.0, 0.0, 0.0],
2122
- confidence: 0.9,
2123
- source_language: "unknown".to_string(),
2124
- source_direction: TextDirection::Horizontal,
2125
- lines,
2126
- angle_deg: 0.0,
2127
- detected_font_size_px: 0.0,
2128
- distances: Vec::new(),
2129
- direction_vec: [1.0, 0.0],
2130
- direction_norm: 1.0,
2131
- merged: false,
2132
- };
2133
- block.adjust_bbox(false);
2134
- examine_block(&mut block, 2000, 2000, true);
2135
- block
2136
- };
2137
-
2138
- let top = make_block(vec![[
2139
- [10.0, 10.0],
2140
- [110.0, 10.0],
2141
- [110.0, 28.0],
2142
- [10.0, 28.0],
2143
- ]]);
2144
- let blocker = make_block(vec![[
2145
- [42.0, 32.0],
2146
- [78.0, 32.0],
2147
- [78.0, 76.0],
2148
- [42.0, 76.0],
2149
- ]]);
2150
- let bottom = make_block(vec![[
2151
- [12.0, 38.0],
2152
- [108.0, 38.0],
2153
- [108.0, 56.0],
2154
- [12.0, 56.0],
2155
- ]]);
2156
-
2157
- let merged = merge_paragraph_blocks(vec![top, blocker, bottom], 2000, 2000);
2158
- assert_eq!(merged.len(), 3);
2159
- }
2160
-
2161
- #[test]
2162
- fn sort_regions_stays_stable_across_row_boundaries() {
2163
- let left_lower = test_block([0.0, 9.0, 2.0, 11.0], TextDirection::Horizontal);
2164
- let middle = test_block([1.0, 4.0, 3.0, 6.0], TextDirection::Horizontal);
2165
- let right_upper = test_block([2.0, -1.0, 4.0, 1.0], TextDirection::Horizontal);
2166
- let sorted = sort_regions(vec![left_lower, middle, right_upper]);
2167
- let bboxes = sorted.iter().map(|block| block.bbox).collect::<Vec<_>>();
2168
-
2169
- assert_eq!(
2170
- bboxes,
2171
- vec![
2172
- [0.0, 9.0, 2.0, 11.0],
2173
- [1.0, 4.0, 3.0, 6.0],
2174
- [2.0, -1.0, 4.0, 1.0],
2175
- ]
2176
- );
2177
- }
2178
-
2179
- #[test]
2180
- fn sort_regions_orders_vertical_blocks_right_to_left_then_top_to_bottom() {
2181
- let sorted = sort_regions(vec![
2182
- test_block([10.0, 20.0, 18.0, 28.0], TextDirection::Vertical),
2183
- test_block([30.0, 15.0, 38.0, 23.0], TextDirection::Vertical),
2184
- test_block([30.0, 5.0, 38.0, 13.0], TextDirection::Vertical),
2185
- ]);
2186
- let bboxes = sorted.iter().map(|block| block.bbox).collect::<Vec<_>>();
2187
-
2188
- assert_eq!(
2189
- bboxes,
2190
- vec![
2191
- [30.0, 5.0, 38.0, 13.0],
2192
- [30.0, 15.0, 38.0, 23.0],
2193
- [10.0, 20.0, 18.0, 28.0],
2194
- ]
2195
- );
2196
- }
2197
  }
 
3
  imageops::{self},
4
  };
5
  use imageproc::{
 
 
6
  distance_transform::Norm,
 
7
  geometric_transformations::{Interpolation, Projection, warp_into},
8
+ morphology::dilate,
 
 
9
  };
10
  use koharu_core::{TextBlock, TextDirection};
11
 
 
 
 
 
12
  const FINAL_MASK_DILATE_RADIUS: u8 = 2;
13
 
14
  pub type Quad = [[f32; 2]; 4];
 
22
  pub mask: GrayImage,
23
  }
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  pub fn refine_segmentation_mask(
26
+ _image: &DynamicImage,
27
  pred_mask: &GrayImage,
28
+ _blocks: &[TextBlock],
29
  ) -> GrayImage {
30
+ let base = threshold_binary(pred_mask, 60);
 
 
 
 
 
 
 
 
 
 
31
  dilate(&base, Norm::L1, FINAL_MASK_DILATE_RADIUS)
32
  }
33
 
 
216
  out
217
  }
218
 
219
+ fn clip_quad(quad: &Quad, width: f32, height: f32) -> Quad {
220
+ let mut clipped = *quad;
221
+ for point in &mut clipped {
222
+ point[0] = point[0].clamp(0.0, width);
223
+ point[1] = point[1].clamp(0.0, height);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  }
225
+ clipped
226
  }
227
 
228
+ fn quad_bbox(quad: &Quad) -> [f32; 4] {
 
 
 
 
229
  let mut min_x = f32::MAX;
230
  let mut min_y = f32::MAX;
231
  let mut max_x = f32::MIN;
232
  let mut max_y = f32::MIN;
233
+ for point in quad {
234
  min_x = min_x.min(point[0]);
235
  min_y = min_y.min(point[1]);
236
  max_x = max_x.max(point[0]);
237
  max_y = max_y.max(point[1]);
238
  }
239
+ [min_x, min_y, max_x, max_y]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  }
241
 
242
+ fn quad_to_tuples(quad: &Quad) -> [(f32, f32); 4] {
243
+ [
244
+ (quad[0][0], quad[0][1]),
245
+ (quad[1][0], quad[1][1]),
246
+ (quad[2][0], quad[2][1]),
247
+ (quad[3][0], quad[3][1]),
248
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  }
250
 
251
+ fn quad_axis_lengths(quad: &Quad) -> (f32, f32) {
252
+ let midpoints = [
253
+ midpoint(quad[0], quad[1]),
254
+ midpoint(quad[1], quad[2]),
255
+ midpoint(quad[2], quad[3]),
256
+ midpoint(quad[3], quad[0]),
257
+ ];
258
+ let vec_v = [
259
+ midpoints[2][0] - midpoints[0][0],
260
+ midpoints[2][1] - midpoints[0][1],
261
+ ];
262
+ let vec_h = [
263
+ midpoints[1][0] - midpoints[3][0],
264
+ midpoints[1][1] - midpoints[3][1],
265
+ ];
266
+ (vector_norm(vec_v), vector_norm(vec_h))
 
267
  }
268
 
269
+ fn midpoint(a: [f32; 2], b: [f32; 2]) -> [f32; 2] {
270
+ [(a[0] + b[0]) * 0.5, (a[1] + b[1]) * 0.5]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  }
272
 
273
+ fn vector_norm(vector: [f32; 2]) -> f32 {
274
+ (vector[0] * vector[0] + vector[1] * vector[1]).sqrt()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  }
276
 
277
+ fn threshold_binary(image: &GrayImage, threshold: u8) -> GrayImage {
278
+ GrayImage::from_fn(image.width(), image.height(), |x, y| {
279
+ if image.get_pixel(x, y)[0] > threshold {
280
+ Luma([255u8])
281
+ } else {
282
+ Luma([0u8])
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  }
284
+ })
 
 
 
285
  }
286
 
287
+ #[cfg(test)]
288
+ mod tests {
289
+ use super::*;
 
 
 
 
 
 
 
 
 
 
 
290
 
291
+ #[test]
292
+ fn refine_segmentation_mask_thresholds_when_blocks_are_missing() {
293
+ let image = DynamicImage::ImageRgb8(RgbImage::from_pixel(16, 16, Rgb([255, 255, 255])));
294
+ let pred_mask = GrayImage::from_fn(16, 16, |x, y| {
295
+ if (4..12).contains(&x) && (5..11).contains(&y) {
296
+ Luma([200])
297
+ } else {
298
+ Luma([0])
299
  }
300
+ });
301
+
302
+ let mask = refine_segmentation_mask(&image, &pred_mask, &[]);
303
+ assert_eq!(mask.get_pixel(0, 0)[0], 0);
304
+ assert!(mask.get_pixel(8, 8)[0] > 0);
305
  }
306
 
307
+ #[test]
308
+ fn refine_segmentation_mask_ignores_blocks() {
309
+ let image = DynamicImage::ImageRgb8(RgbImage::from_pixel(32, 32, Rgb([255, 255, 255])));
310
+ let pred_mask = GrayImage::from_fn(32, 32, |x, y| {
311
+ if (8..24).contains(&x) && (10..22).contains(&y) {
312
+ Luma([200])
313
+ } else {
314
+ Luma([0])
315
+ }
316
+ });
317
+ let block = TextBlock {
318
+ x: 10.0,
319
+ y: 11.0,
320
+ width: 8.0,
321
+ height: 6.0,
322
+ line_polygons: Some(vec![[
323
+ [10.0, 11.0],
324
+ [18.0, 11.0],
325
+ [18.0, 17.0],
326
+ [10.0, 17.0],
327
+ ]]),
328
+ source_direction: Some(TextDirection::Horizontal),
329
+ rotation_deg: Some(0.0),
330
+ detected_font_size_px: Some(6.0),
331
+ detector: Some("ctd".to_string()),
332
+ ..Default::default()
333
+ };
334
 
335
+ let with_blocks = refine_segmentation_mask(&image, &pred_mask, &[block]);
336
+ let without_blocks = refine_segmentation_mask(&image, &pred_mask, &[]);
337
+ assert_eq!(with_blocks, without_blocks);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  }
339
 
340
  #[test]
341
+ fn extract_text_block_regions_falls_back_to_bbox_without_lines() {
342
+ let image = DynamicImage::ImageRgb8(RgbImage::from_pixel(24, 24, Rgb([255, 255, 255])));
343
  let block = TextBlock {
344
  x: 4.0,
345
+ y: 5.0,
346
  width: 10.0,
347
+ height: 8.0,
348
  ..Default::default()
349
  };
350
 
351
  let regions = extract_text_block_regions(&image, &block);
352
  assert_eq!(regions.len(), 1);
353
  assert_eq!(regions[0].width(), 10);
354
+ assert_eq!(regions[0].height(), 8);
355
  }
356
 
357
  #[test]
 
379
  assert!(crop.width() > 12);
380
  assert!(crop.height() > 8);
381
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  }
koharu-ml/tests/comic_text_detector.rs CHANGED
@@ -14,25 +14,16 @@ async fn comic_text_detector() -> anyhow::Result<()> {
14
  let img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/1.jpg"))?;
15
  let detection = model.inference(&img)?;
16
 
17
- assert!(
18
- !detection.text_blocks.is_empty(),
19
- "expected CTD blocks, got line_polygons={}, mask_pixels={}",
20
- detection.line_polygons.len(),
21
- detection.mask.iter().filter(|&&v| v > 0u8).count()
22
- );
23
- assert!(
24
- detection.mask.iter().any(|&v| v > 0u8),
25
- "expected CTD mask, got line_polygons={}",
26
- detection.line_polygons.len()
27
- );
28
  assert_eq!(detection.shrink_map.dimensions(), img.dimensions());
29
  assert_eq!(detection.threshold_map.dimensions(), img.dimensions());
30
- assert!(detection.line_polygons.iter().all(|line| line.len() == 4));
31
  assert!(
32
  detection
33
  .text_blocks
34
  .iter()
35
- .any(|block| block.detector.as_deref() == Some("ctd"))
36
  );
37
 
38
  Ok(())
 
14
  let img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/1.jpg"))?;
15
  let detection = model.inference(&img)?;
16
 
17
+ assert!(!detection.text_blocks.is_empty(), "expected CTD boxes");
18
+ assert!(detection.mask.iter().any(|&v| v > 0u8), "expected CTD mask");
 
 
 
 
 
 
 
 
 
19
  assert_eq!(detection.shrink_map.dimensions(), img.dimensions());
20
  assert_eq!(detection.threshold_map.dimensions(), img.dimensions());
21
+ assert!(detection.line_polygons.is_empty());
22
  assert!(
23
  detection
24
  .text_blocks
25
  .iter()
26
+ .all(|block| block.line_polygons.is_none() && block.detector.is_none())
27
  );
28
 
29
  Ok(())
ui/app/(app)/page.tsx CHANGED
@@ -43,7 +43,7 @@ export default function Page() {
43
  onLayoutChanged={onLayoutChanged}
44
  className='flex min-h-0 flex-1'
45
  >
46
- <Panel id='left' defaultSize={220} minSize={160} maxSize={360}>
47
  <Navigator />
48
  </Panel>
49
  <Separator className='bg-border/40 hover:bg-border w-1 transition-colors' />
@@ -56,7 +56,7 @@ export default function Page() {
56
  </AppErrorBoundary>
57
  </Panel>
58
  <Separator className='bg-border/40 hover:bg-border w-1 transition-colors' />
59
- <Panel id='right' defaultSize={320} minSize={320} maxSize={460}>
60
  <AppErrorBoundary>
61
  <Panels />
62
  </AppErrorBoundary>
 
43
  onLayoutChanged={onLayoutChanged}
44
  className='flex min-h-0 flex-1'
45
  >
46
+ <Panel id='left' defaultSize={180} minSize={120} maxSize={300}>
47
  <Navigator />
48
  </Panel>
49
  <Separator className='bg-border/40 hover:bg-border w-1 transition-colors' />
 
56
  </AppErrorBoundary>
57
  </Panel>
58
  <Separator className='bg-border/40 hover:bg-border w-1 transition-colors' />
59
+ <Panel id='right' defaultSize={280} minSize={260} maxSize={400}>
60
  <AppErrorBoundary>
61
  <Panels />
62
  </AppErrorBoundary>
ui/components/Panels.tsx CHANGED
@@ -56,7 +56,10 @@ export function Panels() {
56
  className='min-h-0 flex-1 px-2 pb-2 data-[state=inactive]:hidden'
57
  data-testid='panels-layout'
58
  >
59
- <ScrollArea className='h-full' viewportClassName='pr-1'>
 
 
 
60
  <div className='pt-1'>
61
  <RenderControlsPanel />
62
  </div>
 
56
  className='min-h-0 flex-1 px-2 pb-2 data-[state=inactive]:hidden'
57
  data-testid='panels-layout'
58
  >
59
+ <ScrollArea
60
+ className='h-full'
61
+ viewportClassName='pr-1 [&>div]:!block'
62
+ >
63
  <div className='pt-1'>
64
  <RenderControlsPanel />
65
  </div>
ui/components/canvas/TextBlockLayer.tsx CHANGED
@@ -50,6 +50,14 @@ export function TextBlockLayer({
50
  pointerEvents: 'none',
51
  }}
52
  >
 
 
 
 
 
 
 
 
53
  {textBlocks.map((block, index) => (
54
  <TextBlockItem
55
  key={block.id ?? `fallback-${index}`}
@@ -58,7 +66,6 @@ export function TextBlockLayer({
58
  scale={scale}
59
  selected={index === selectedIndex}
60
  interactive={interactive}
61
- showSprite={showSprites}
62
  onSelect={onSelect}
63
  onUpdate={(updates) => void replaceBlock(index, updates)}
64
  />
@@ -73,7 +80,6 @@ type TextBlockItemProps = {
73
  scale: number
74
  selected: boolean
75
  interactive: boolean
76
- showSprite?: boolean
77
  onSelect: (index: number) => void
78
  onUpdate: (updates: Partial<TextBlock>) => void
79
  }
@@ -93,7 +99,6 @@ function TextBlockItem({
93
  scale,
94
  selected,
95
  interactive,
96
- showSprite,
97
  onSelect,
98
  onUpdate,
99
  }: TextBlockItemProps) {
@@ -213,9 +218,6 @@ function TextBlockItem({
213
  cursor: interactive ? 'move' : 'default',
214
  }}
215
  >
216
- {/* Sprite image at natural size */}
217
- {showSprite && <BlockSprite hash={block.rendered} />}
218
-
219
  {/* Annotation border */}
220
  <div
221
  className={`absolute inset-0 rounded ${
@@ -242,15 +244,23 @@ function TextBlockItem({
242
  )
243
  }
244
 
245
- function BlockSprite({ hash }: { hash?: string }) {
246
- const { data: src } = useBlobImage(hash)
247
  if (!src) return null
 
 
248
  return (
249
  <img
250
  alt=''
251
  src={src}
252
  draggable={false}
253
- className='pointer-events-none absolute top-0 left-0 select-none'
 
 
 
 
 
 
254
  />
255
  )
256
  }
 
50
  pointerEvents: 'none',
51
  }}
52
  >
53
+ {showSprites &&
54
+ textBlocks.map((block, index) => (
55
+ <BlockSprite
56
+ key={`sprite-${block.id ?? index}`}
57
+ block={block}
58
+ scale={scale}
59
+ />
60
+ ))}
61
  {textBlocks.map((block, index) => (
62
  <TextBlockItem
63
  key={block.id ?? `fallback-${index}`}
 
66
  scale={scale}
67
  selected={index === selectedIndex}
68
  interactive={interactive}
 
69
  onSelect={onSelect}
70
  onUpdate={(updates) => void replaceBlock(index, updates)}
71
  />
 
80
  scale: number
81
  selected: boolean
82
  interactive: boolean
 
83
  onSelect: (index: number) => void
84
  onUpdate: (updates: Partial<TextBlock>) => void
85
  }
 
99
  scale,
100
  selected,
101
  interactive,
 
102
  onSelect,
103
  onUpdate,
104
  }: TextBlockItemProps) {
 
218
  cursor: interactive ? 'move' : 'default',
219
  }}
220
  >
 
 
 
221
  {/* Annotation border */}
222
  <div
223
  className={`absolute inset-0 rounded ${
 
244
  )
245
  }
246
 
247
+ function BlockSprite({ block, scale }: { block: TextBlock; scale: number }) {
248
+ const { data: src } = useBlobImage(block.rendered)
249
  if (!src) return null
250
+ const x = (block.renderX ?? block.x) * scale
251
+ const y = (block.renderY ?? block.y) * scale
252
  return (
253
  <img
254
  alt=''
255
  src={src}
256
  draggable={false}
257
+ className='pointer-events-none absolute select-none'
258
+ style={{
259
+ top: 0,
260
+ left: 0,
261
+ transformOrigin: 'top left',
262
+ transform: `translate(${x}px, ${y}px) scale(${scale})`,
263
+ }}
264
  />
265
  )
266
  }
ui/components/panels/RenderControlsPanel.tsx CHANGED
@@ -29,13 +29,7 @@ import {
29
  TooltipContent,
30
  TooltipTrigger,
31
  } from '@/components/ui/tooltip'
32
- import {
33
- Select,
34
- SelectContent,
35
- SelectItem,
36
- SelectTrigger,
37
- SelectValue,
38
- } from '@/components/ui/select'
39
  import { useEditorUiStore } from '@/lib/stores/editorUiStore'
40
  import { usePreferencesStore } from '@/lib/stores/preferencesStore'
41
  import { useListFonts } from '@/lib/api/system/system'
@@ -349,7 +343,8 @@ export function RenderControlsPanel() {
349
  ]
350
 
351
  return (
352
- <div className='flex w-full min-w-0 flex-col gap-1.5'>
 
353
  <div className='flex items-center justify-end'>
354
  <span
355
  data-testid='render-scope-indicator'
@@ -362,16 +357,30 @@ export function RenderControlsPanel() {
362
  </span>
363
  </div>
364
 
365
- <div className='grid w-full min-w-0 grid-cols-[3.5rem_minmax(0,1fr)] items-center gap-1.5'>
366
- <span className='text-muted-foreground text-[10px] font-medium tracking-wide uppercase'>
367
- {fontLabel}
368
- </span>
369
-
 
 
 
 
 
370
  <div className='flex min-w-0 items-center gap-1.5'>
371
  <div className='min-w-0 flex-1'>
372
- <Select
 
373
  value={currentFont}
374
- onValueChange={(value) => {
 
 
 
 
 
 
 
 
375
  const nextFamilies = mergeFontFamilies(
376
  value,
377
  selectedBlock?.style?.fontFamilies,
@@ -390,135 +399,92 @@ export function RenderControlsPanel() {
390
  }))
391
  void updateTextBlocks(nextBlocks)
392
  }}
393
- disabled={fontOptions.length === 0}
394
- >
395
- <SelectTrigger
396
- data-testid='render-font-select'
397
- size='sm'
398
- className='h-7 w-full min-w-0 text-xs'
399
- style={
400
- currentFontFamilyName
401
- ? { fontFamily: currentFontFamilyName }
402
- : undefined
403
- }
404
- >
405
- <SelectValue placeholder={t('render.fontPlaceholder')} />
406
- </SelectTrigger>
407
- <SelectContent position='popper'>
408
- {fontOptions.map((font, index) => (
409
- <SelectItem
410
- key={font.postScriptName}
411
- value={font.postScriptName}
412
- style={{ fontFamily: font.familyName }}
413
- data-testid={`render-font-option-${index}`}
414
- >
415
- {font.familyName}
416
- </SelectItem>
417
- ))}
418
- </SelectContent>
419
- </Select>
420
  </div>
421
-
422
- <Tooltip>
423
- <TooltipTrigger asChild>
424
- <div>
425
- <ColorPicker
426
- value={currentColorHex}
427
- disabled={!hasBlocks}
428
- triggerTestId='render-color-trigger'
429
- pickerTestId='render-color-picker'
430
- swatchTestId='render-color-swatch'
431
- inputTestId='render-color-input'
432
- pickButtonTestId='render-color-pick'
433
- onChange={(hex) => {
434
- const nextColor = hexToColor(hex, currentColor[3] ?? 255)
435
- if (applyStyleToSelected({ color: nextColor })) return
436
- applyStyleToAll({ color: nextColor })
437
- }}
438
- className='h-7 w-7'
439
- />
440
- </div>
441
- </TooltipTrigger>
442
- <TooltipContent side='bottom' sideOffset={4}>
443
- {t('render.fontColorLabel')}
444
- </TooltipContent>
445
- </Tooltip>
446
  </div>
447
  </div>
448
 
449
- <div className='grid w-full min-w-0 grid-cols-[3.5rem_minmax(0,1fr)] items-center gap-1.5'>
450
- <span className='text-muted-foreground text-[10px] font-medium tracking-wide uppercase'>
 
451
  {fontSizeLabel}
452
  </span>
453
-
454
- <div className='flex min-w-0 items-center gap-1'>
455
- <div className='border-input bg-background flex w-auto min-w-0 shrink-0 items-center rounded-md border shadow-xs'>
456
- <Button
457
- type='button'
458
- variant='ghost'
459
- size='icon-sm'
460
- aria-label={`${fontSizeLabel} -`}
461
- className='size-7 rounded-r-none border-r'
462
- disabled={selectedBlockIndex === undefined}
463
- onClick={() => {
464
- const next = Math.max(
465
- 6,
466
- Math.round((currentFontSize ?? 16) - 1),
467
- )
468
- applyStyleToSelected({ fontSize: next })
469
- }}
470
- >
471
- <MinusIcon className='size-3' />
472
- </Button>
473
-
474
- <Input
475
- type='number'
476
- step='1'
477
- min='6'
478
- max='300'
479
- inputMode='numeric'
480
- className='h-7 w-14 min-w-0 [appearance:textfield] rounded-none border-0 px-1.5 text-center text-[11px] shadow-none focus-visible:ring-0 [&::-webkit-inner-spin-button]:appearance-none [&::-webkit-outer-spin-button]:appearance-none'
481
- data-testid='render-font-size'
482
- disabled={selectedBlockIndex === undefined}
483
- value={
484
- currentFontSize !== undefined ? Math.round(currentFontSize) : ''
485
- }
486
- placeholder='auto'
487
- onChange={(event) => {
488
- const parsed = Number.parseInt(event.target.value, 10)
489
- if (!Number.isFinite(parsed) || parsed < 1) return
490
- applyStyleToSelected({ fontSize: Math.min(300, parsed) })
491
- }}
492
- />
493
-
494
- <Button
495
- type='button'
496
- variant='ghost'
497
- size='icon-sm'
498
- aria-label={`${fontSizeLabel} +`}
499
- className='size-7 rounded-l-none border-l'
500
- disabled={selectedBlockIndex === undefined}
501
- onClick={() => {
502
- const next = Math.min(
503
- 300,
504
- Math.round((currentFontSize ?? 16) + 1),
505
- )
506
- applyStyleToSelected({ fontSize: next })
507
- }}
508
- >
509
- <PlusIcon className='size-3' />
510
- </Button>
511
- </div>
512
- <span className='text-muted-foreground text-[10px]'>px</span>
513
- </div>
514
- </div>
515
-
516
- <div className='grid w-full min-w-0 grid-cols-[3.5rem_minmax(0,1fr)] items-center gap-1.5'>
517
- <span className='text-muted-foreground text-[10px] font-medium tracking-wide uppercase'>
518
  {effectLabel}
519
  </span>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
 
521
- <div className='flex min-w-0 flex-wrap items-center gap-1'>
522
  {effectItems.map((item) => {
523
  const active = currentEffect[item.key]
524
  const Icon = item.Icon
@@ -531,7 +497,7 @@ export function RenderControlsPanel() {
531
  aria-label={item.label}
532
  data-testid={`render-effect-toggle-${item.key}`}
533
  className={cn(
534
- 'size-7',
535
  active &&
536
  'bg-primary text-primary-foreground border-primary hover:bg-primary/90',
537
  )}
@@ -555,14 +521,8 @@ export function RenderControlsPanel() {
555
  )
556
  })}
557
  </div>
558
- </div>
559
-
560
- <div className='grid w-full min-w-0 grid-cols-[3.5rem_minmax(0,1fr)] items-center gap-1.5'>
561
- <span className='text-muted-foreground text-[10px] font-medium tracking-wide uppercase'>
562
- {alignLabel}
563
- </span>
564
 
565
- <div className='flex min-w-0 flex-wrap items-center gap-1'>
566
  {textAlignItems.map((item) => {
567
  const active = currentTextAlign === item.value
568
  const Icon = item.Icon
@@ -576,7 +536,7 @@ export function RenderControlsPanel() {
576
  data-testid={`render-align-${item.value}`}
577
  disabled={!hasBlocks}
578
  className={cn(
579
- 'size-7',
580
  active &&
581
  'bg-primary text-primary-foreground border-primary hover:bg-primary/90',
582
  )}
@@ -598,12 +558,12 @@ export function RenderControlsPanel() {
598
  </div>
599
  </div>
600
 
601
- <div className='grid w-full min-w-0 grid-cols-[3.5rem_minmax(0,1fr)] items-center gap-1.5'>
602
- <span className='text-muted-foreground text-[10px] font-medium tracking-wide uppercase'>
 
603
  {strokeLabel}
604
  </span>
605
-
606
- <div className='flex min-w-0 flex-wrap items-center gap-1'>
607
  <Tooltip>
608
  <TooltipTrigger asChild>
609
  <Button
@@ -648,7 +608,7 @@ export function RenderControlsPanel() {
648
  color: hexToColor(hex, currentStroke.color[3] ?? 255),
649
  })
650
  }}
651
- className='h-7 w-7'
652
  />
653
  </div>
654
  </TooltipTrigger>
@@ -657,60 +617,53 @@ export function RenderControlsPanel() {
657
  </TooltipContent>
658
  </Tooltip>
659
 
660
- <Tooltip>
661
- <TooltipTrigger asChild>
662
- <div className='border-input bg-background flex w-auto min-w-0 shrink-0 items-center rounded-md border shadow-xs'>
663
- <Button
664
- type='button'
665
- variant='ghost'
666
- size='icon-sm'
667
- aria-label={`${strokeWidthLabel} -`}
668
- className='size-7 rounded-r-none border-r'
669
- onClick={() =>
670
- updateStrokeWidth(currentStrokeWidth - STROKE_WIDTH_STEP)
671
- }
672
- >
673
- <MinusIcon className='size-3' />
674
- </Button>
675
-
676
- <Input
677
- type='number'
678
- step={String(STROKE_WIDTH_STEP)}
679
- min={String(MIN_STROKE_WIDTH)}
680
- max={String(MAX_STROKE_WIDTH)}
681
- inputMode='decimal'
682
- className='h-7 w-14 min-w-0 [appearance:textfield] rounded-none border-0 px-1.5 text-center text-[11px] shadow-none focus-visible:ring-0 [&::-webkit-inner-spin-button]:appearance-none [&::-webkit-outer-spin-button]:appearance-none'
683
- data-testid='render-stroke-width'
684
- value={
685
- Number.isFinite(currentStrokeWidth)
686
- ? currentStrokeWidth
687
- : DEFAULT_STROKE_WIDTH
688
- }
689
- onChange={(event) => {
690
- const parsed = Number.parseFloat(event.target.value)
691
- if (!Number.isFinite(parsed)) return
692
- updateStrokeWidth(parsed)
693
- }}
694
- />
695
 
696
- <Button
697
- type='button'
698
- variant='ghost'
699
- size='icon-sm'
700
- aria-label={`${strokeWidthLabel} +`}
701
- className='size-7 rounded-l-none border-l'
702
- onClick={() =>
703
- updateStrokeWidth(currentStrokeWidth + STROKE_WIDTH_STEP)
704
- }
705
- >
706
- <PlusIcon className='size-3' />
707
- </Button>
708
- </div>
709
- </TooltipTrigger>
710
- <TooltipContent side='bottom' sideOffset={4}>
711
- {strokeWidthLabel}
712
- </TooltipContent>
713
- </Tooltip>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
  </div>
715
  </div>
716
  </div>
 
29
  TooltipContent,
30
  TooltipTrigger,
31
  } from '@/components/ui/tooltip'
32
+ import { FontSelect } from '@/components/ui/font-select'
 
 
 
 
 
 
33
  import { useEditorUiStore } from '@/lib/stores/editorUiStore'
34
  import { usePreferencesStore } from '@/lib/stores/preferencesStore'
35
  import { useListFonts } from '@/lib/api/system/system'
 
343
  ]
344
 
345
  return (
346
+ <div className='flex w-full min-w-0 flex-col gap-2'>
347
+ {/* Scope indicator */}
348
  <div className='flex items-center justify-end'>
349
  <span
350
  data-testid='render-scope-indicator'
 
357
  </span>
358
  </div>
359
 
360
+ {/* Font + Color */}
361
+ <div className='flex flex-col gap-0.5'>
362
+ <div className='flex items-baseline justify-between'>
363
+ <span className='text-muted-foreground text-[10px] font-medium uppercase'>
364
+ {fontLabel}
365
+ </span>
366
+ <span className='text-muted-foreground text-[10px] font-medium uppercase'>
367
+ {t('render.fontColorLabel')}
368
+ </span>
369
+ </div>
370
  <div className='flex min-w-0 items-center gap-1.5'>
371
  <div className='min-w-0 flex-1'>
372
+ <FontSelect
373
+ data-testid='render-font-select'
374
  value={currentFont}
375
+ options={fontOptions}
376
+ disabled={fontOptions.length === 0}
377
+ placeholder={t('render.fontPlaceholder')}
378
+ triggerStyle={
379
+ currentFontFamilyName
380
+ ? { fontFamily: currentFontFamilyName }
381
+ : undefined
382
+ }
383
+ onChange={(value) => {
384
  const nextFamilies = mergeFontFamilies(
385
  value,
386
  selectedBlock?.style?.fontFamilies,
 
399
  }))
400
  void updateTextBlocks(nextBlocks)
401
  }}
402
+ />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  </div>
404
+ <ColorPicker
405
+ value={currentColorHex}
406
+ disabled={!hasBlocks}
407
+ triggerTestId='render-color-trigger'
408
+ pickerTestId='render-color-picker'
409
+ swatchTestId='render-color-swatch'
410
+ inputTestId='render-color-input'
411
+ pickButtonTestId='render-color-pick'
412
+ onChange={(hex) => {
413
+ const nextColor = hexToColor(hex, currentColor[3] ?? 255)
414
+ if (applyStyleToSelected({ color: nextColor })) return
415
+ applyStyleToAll({ color: nextColor })
416
+ }}
417
+ className='size-7'
418
+ />
 
 
 
 
 
 
 
 
 
 
419
  </div>
420
  </div>
421
 
422
+ {/* Size / Effect / Align */}
423
+ <div className='grid w-full grid-cols-[minmax(0,1fr)_auto_auto] items-end gap-x-2'>
424
+ <span className='text-muted-foreground text-[10px] font-medium uppercase'>
425
  {fontSizeLabel}
426
  </span>
427
+ <span className='text-muted-foreground text-[10px] font-medium uppercase'>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  {effectLabel}
429
  </span>
430
+ <span className='text-muted-foreground text-[10px] font-medium uppercase'>
431
+ {alignLabel}
432
+ </span>
433
+
434
+ <div className='border-input bg-background flex min-w-0 items-center rounded-md border shadow-xs'>
435
+ <Button
436
+ type='button'
437
+ variant='ghost'
438
+ size='icon-sm'
439
+ aria-label={`${fontSizeLabel} -`}
440
+ className='size-7 shrink-0 rounded-r-none border-r'
441
+ disabled={selectedBlockIndex === undefined}
442
+ onClick={() => {
443
+ const next = Math.max(6, Math.round((currentFontSize ?? 16) - 1))
444
+ applyStyleToSelected({ fontSize: next })
445
+ }}
446
+ >
447
+ <MinusIcon className='size-3' />
448
+ </Button>
449
+ <Input
450
+ type='number'
451
+ step='1'
452
+ min='6'
453
+ max='300'
454
+ inputMode='numeric'
455
+ className='h-7 min-w-0 flex-1 [appearance:textfield] rounded-none border-0 px-1 text-center text-xs shadow-none focus-visible:ring-0 [&::-webkit-inner-spin-button]:appearance-none [&::-webkit-outer-spin-button]:appearance-none'
456
+ data-testid='render-font-size'
457
+ disabled={selectedBlockIndex === undefined}
458
+ value={
459
+ currentFontSize !== undefined ? Math.round(currentFontSize) : ''
460
+ }
461
+ placeholder='auto'
462
+ onChange={(event) => {
463
+ const parsed = Number.parseInt(event.target.value, 10)
464
+ if (!Number.isFinite(parsed) || parsed < 1) return
465
+ applyStyleToSelected({ fontSize: Math.min(300, parsed) })
466
+ }}
467
+ />
468
+ <Button
469
+ type='button'
470
+ variant='ghost'
471
+ size='icon-sm'
472
+ aria-label={`${fontSizeLabel} +`}
473
+ className='size-7 shrink-0 rounded-l-none border-l'
474
+ disabled={selectedBlockIndex === undefined}
475
+ onClick={() => {
476
+ const next = Math.min(
477
+ 300,
478
+ Math.round((currentFontSize ?? 16) + 1),
479
+ )
480
+ applyStyleToSelected({ fontSize: next })
481
+ }}
482
+ >
483
+ <PlusIcon className='size-3' />
484
+ </Button>
485
+ </div>
486
 
487
+ <div className='flex items-center gap-1'>
488
  {effectItems.map((item) => {
489
  const active = currentEffect[item.key]
490
  const Icon = item.Icon
 
497
  aria-label={item.label}
498
  data-testid={`render-effect-toggle-${item.key}`}
499
  className={cn(
500
+ 'size-7 shrink-0',
501
  active &&
502
  'bg-primary text-primary-foreground border-primary hover:bg-primary/90',
503
  )}
 
521
  )
522
  })}
523
  </div>
 
 
 
 
 
 
524
 
525
+ <div className='flex items-center gap-1'>
526
  {textAlignItems.map((item) => {
527
  const active = currentTextAlign === item.value
528
  const Icon = item.Icon
 
536
  data-testid={`render-align-${item.value}`}
537
  disabled={!hasBlocks}
538
  className={cn(
539
+ 'size-7 shrink-0',
540
  active &&
541
  'bg-primary text-primary-foreground border-primary hover:bg-primary/90',
542
  )}
 
558
  </div>
559
  </div>
560
 
561
+ {/* Border / Stroke */}
562
+ <div className='flex flex-col gap-0.5'>
563
+ <span className='text-muted-foreground text-[10px] font-medium uppercase'>
564
  {strokeLabel}
565
  </span>
566
+ <div className='flex min-w-0 items-center gap-1'>
 
567
  <Tooltip>
568
  <TooltipTrigger asChild>
569
  <Button
 
608
  color: hexToColor(hex, currentStroke.color[3] ?? 255),
609
  })
610
  }}
611
+ className='size-7'
612
  />
613
  </div>
614
  </TooltipTrigger>
 
617
  </TooltipContent>
618
  </Tooltip>
619
 
620
+ <div className='border-input bg-background flex min-w-0 flex-1 items-center rounded-md border shadow-xs'>
621
+ <Button
622
+ type='button'
623
+ variant='ghost'
624
+ size='icon-sm'
625
+ aria-label={`${strokeWidthLabel} -`}
626
+ className='size-7 shrink-0 rounded-r-none border-r'
627
+ onClick={() =>
628
+ updateStrokeWidth(currentStrokeWidth - STROKE_WIDTH_STEP)
629
+ }
630
+ >
631
+ <MinusIcon className='size-3' />
632
+ </Button>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
 
634
+ <Input
635
+ type='number'
636
+ step={String(STROKE_WIDTH_STEP)}
637
+ min={String(MIN_STROKE_WIDTH)}
638
+ max={String(MAX_STROKE_WIDTH)}
639
+ inputMode='decimal'
640
+ className='h-7 min-w-0 flex-1 [appearance:textfield] rounded-none border-0 px-1 text-center text-xs shadow-none focus-visible:ring-0 [&::-webkit-inner-spin-button]:appearance-none [&::-webkit-outer-spin-button]:appearance-none'
641
+ data-testid='render-stroke-width'
642
+ value={
643
+ Number.isFinite(currentStrokeWidth)
644
+ ? currentStrokeWidth
645
+ : DEFAULT_STROKE_WIDTH
646
+ }
647
+ onChange={(event) => {
648
+ const parsed = Number.parseFloat(event.target.value)
649
+ if (!Number.isFinite(parsed)) return
650
+ updateStrokeWidth(parsed)
651
+ }}
652
+ />
653
+
654
+ <Button
655
+ type='button'
656
+ variant='ghost'
657
+ size='icon-sm'
658
+ aria-label={`${strokeWidthLabel} +`}
659
+ className='size-7 shrink-0 rounded-l-none border-l'
660
+ onClick={() =>
661
+ updateStrokeWidth(currentStrokeWidth + STROKE_WIDTH_STEP)
662
+ }
663
+ >
664
+ <PlusIcon className='size-3' />
665
+ </Button>
666
+ </div>
667
  </div>
668
  </div>
669
  </div>
ui/components/ui/color-picker.tsx CHANGED
@@ -1,5 +1,6 @@
1
  'use client'
2
 
 
3
  import { HexColorInput, HexColorPicker } from 'react-colorful'
4
  import { Button } from '@/components/ui/button'
5
  import {
@@ -43,6 +44,21 @@ export function ColorPicker({
43
  inputTestId,
44
  pickButtonTestId,
45
  }: ColorPickerProps) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  const canUseEyeDropper =
47
  typeof window !== 'undefined' &&
48
  typeof (window as EyeDropperWindow).EyeDropper === 'function'
@@ -54,7 +70,9 @@ export function ColorPicker({
54
  try {
55
  const eyeDropper = new EyeDropperCtor()
56
  const result = await eyeDropper.open()
57
- onChange(normalizeHex(result.sRGBHex))
 
 
58
  } catch (error) {
59
  const maybeDomException = error as DOMException | undefined
60
  if (maybeDomException?.name === 'AbortError') return
@@ -76,29 +94,38 @@ export function ColorPicker({
76
  <div
77
  data-testid={swatchTestId}
78
  className='size-4 rounded-sm'
79
- style={{ backgroundColor: value }}
80
  />
81
  </button>
82
  </PopoverTrigger>
83
  <PopoverContent className='w-64 p-3' sideOffset={8}>
84
  <div className='space-y-3'>
85
- <div data-testid={pickerTestId}>
 
86
  <HexColorPicker
87
- color={value}
88
- onChange={(color) => onChange(normalizeHex(color))}
 
 
 
 
89
  />
90
  </div>
91
 
92
  <div className='flex items-center gap-2'>
93
  <HexColorInput
94
- color={value}
95
  prefixed
96
  data-testid={inputTestId}
97
  spellCheck={false}
98
  disabled={disabled}
99
  aria-label='Hex color code'
100
  className='border-input bg-background focus-visible:border-ring focus-visible:ring-ring/50 h-8 min-w-0 flex-1 rounded-md border px-2 font-mono text-xs uppercase shadow-xs transition outline-none focus-visible:ring-[3px]'
101
- onChange={(color) => onChange(normalizeHex(color))}
 
 
 
 
102
  />
103
 
104
  {canUseEyeDropper && (
 
1
  'use client'
2
 
3
+ import { useRef, useState, useCallback, useEffect } from 'react'
4
  import { HexColorInput, HexColorPicker } from 'react-colorful'
5
  import { Button } from '@/components/ui/button'
6
  import {
 
44
  inputTestId,
45
  pickButtonTestId,
46
  }: ColorPickerProps) {
47
+ const [localColor, setLocalColor] = useState(value)
48
+ const dragging = useRef(false)
49
+
50
+ // Sync external value when not dragging
51
+ useEffect(() => {
52
+ if (!dragging.current) setLocalColor(value)
53
+ }, [value])
54
+
55
+ const handlePointerUp = useCallback(() => {
56
+ if (dragging.current) {
57
+ dragging.current = false
58
+ onChange(localColor)
59
+ }
60
+ }, [localColor, onChange])
61
+
62
  const canUseEyeDropper =
63
  typeof window !== 'undefined' &&
64
  typeof (window as EyeDropperWindow).EyeDropper === 'function'
 
70
  try {
71
  const eyeDropper = new EyeDropperCtor()
72
  const result = await eyeDropper.open()
73
+ const color = normalizeHex(result.sRGBHex)
74
+ setLocalColor(color)
75
+ onChange(color)
76
  } catch (error) {
77
  const maybeDomException = error as DOMException | undefined
78
  if (maybeDomException?.name === 'AbortError') return
 
94
  <div
95
  data-testid={swatchTestId}
96
  className='size-4 rounded-sm'
97
+ style={{ backgroundColor: localColor }}
98
  />
99
  </button>
100
  </PopoverTrigger>
101
  <PopoverContent className='w-64 p-3' sideOffset={8}>
102
  <div className='space-y-3'>
103
+ {/* eslint-disable-next-line jsx-a11y/no-static-element-interactions */}
104
+ <div data-testid={pickerTestId} onPointerUp={handlePointerUp}>
105
  <HexColorPicker
106
+ color={localColor}
107
+ onChange={(color) => {
108
+ const normalized = normalizeHex(color)
109
+ dragging.current = true
110
+ setLocalColor(normalized)
111
+ }}
112
  />
113
  </div>
114
 
115
  <div className='flex items-center gap-2'>
116
  <HexColorInput
117
+ color={localColor}
118
  prefixed
119
  data-testid={inputTestId}
120
  spellCheck={false}
121
  disabled={disabled}
122
  aria-label='Hex color code'
123
  className='border-input bg-background focus-visible:border-ring focus-visible:ring-ring/50 h-8 min-w-0 flex-1 rounded-md border px-2 font-mono text-xs uppercase shadow-xs transition outline-none focus-visible:ring-[3px]'
124
+ onChange={(color) => {
125
+ const normalized = normalizeHex(color)
126
+ setLocalColor(normalized)
127
+ onChange(normalized)
128
+ }}
129
  />
130
 
131
  {canUseEyeDropper && (
ui/components/ui/font-select.tsx ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client'
2
+
3
+ import { useRef, useState, useMemo, useCallback } from 'react'
4
+ import { useVirtualizer } from '@tanstack/react-virtual'
5
+ import { CheckIcon, ChevronDownIcon } from 'lucide-react'
6
+ import {
7
+ Popover,
8
+ PopoverContent,
9
+ PopoverTrigger,
10
+ } from '@/components/ui/popover'
11
+ import { ScrollArea } from '@/components/ui/scroll-area'
12
+ import { cn } from '@/lib/utils'
13
+
14
+ const ITEM_HEIGHT = 28
15
+ const MAX_VISIBLE = 10
16
+
17
+ type FontOption = {
18
+ familyName: string
19
+ postScriptName: string
20
+ }
21
+
22
+ type FontSelectProps = {
23
+ value: string
24
+ options: FontOption[]
25
+ disabled?: boolean
26
+ placeholder?: string
27
+ className?: string
28
+ triggerClassName?: string
29
+ triggerStyle?: React.CSSProperties
30
+ onChange: (value: string) => void
31
+ 'data-testid'?: string
32
+ }
33
+
34
+ export function FontSelect({
35
+ value,
36
+ options,
37
+ disabled,
38
+ placeholder,
39
+ className,
40
+ triggerClassName,
41
+ triggerStyle,
42
+ onChange,
43
+ ...props
44
+ }: FontSelectProps) {
45
+ const [open, setOpen] = useState(false)
46
+ const [search, setSearch] = useState('')
47
+ const scrollRef = useRef<HTMLDivElement>(null)
48
+ const inputRef = useRef<HTMLInputElement>(null)
49
+
50
+ const filtered = useMemo(() => {
51
+ if (!search) return options
52
+ const lower = search.toLowerCase()
53
+ return options.filter((f) => f.familyName.toLowerCase().includes(lower))
54
+ }, [options, search])
55
+
56
+ const virtualizer = useVirtualizer({
57
+ count: filtered.length,
58
+ getScrollElement: () => scrollRef.current,
59
+ estimateSize: () => ITEM_HEIGHT,
60
+ overscan: 5,
61
+ enabled: open,
62
+ })
63
+
64
+ const viewportRef = useCallback(
65
+ (node: HTMLDivElement | null) => {
66
+ scrollRef.current = node
67
+ if (node) virtualizer.measure()
68
+ },
69
+ // eslint-disable-next-line react-hooks/exhaustive-deps
70
+ [open],
71
+ )
72
+
73
+ const selectedLabel = options.find(
74
+ (f) => f.postScriptName === value || f.familyName === value,
75
+ )?.familyName
76
+
77
+ const listHeight = Math.min(filtered.length, MAX_VISIBLE) * ITEM_HEIGHT
78
+
79
+ return (
80
+ <Popover
81
+ open={open}
82
+ onOpenChange={(next) => {
83
+ setOpen(next)
84
+ if (!next) setSearch('')
85
+ }}
86
+ >
87
+ <PopoverTrigger
88
+ disabled={disabled}
89
+ data-testid={props['data-testid']}
90
+ className={cn(
91
+ "border-input data-[placeholder]:text-muted-foreground [&_svg:not([class*='text-'])]:text-muted-foreground focus-visible:border-ring focus-visible:ring-ring/50 dark:bg-input/30 dark:hover:bg-input/50 flex h-7 w-full items-center justify-between gap-1.5 rounded-md border bg-transparent px-2 py-1 text-xs whitespace-nowrap shadow-xs transition-[color,box-shadow] outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50",
92
+ triggerClassName,
93
+ )}
94
+ style={triggerStyle}
95
+ >
96
+ <span className='truncate'>{selectedLabel ?? placeholder ?? ''}</span>
97
+ <ChevronDownIcon className='size-3.5 shrink-0 opacity-50' />
98
+ </PopoverTrigger>
99
+ <PopoverContent
100
+ className={cn('w-(--radix-popover-trigger-width) p-0', className)}
101
+ align='start'
102
+ onOpenAutoFocus={(e) => {
103
+ e.preventDefault()
104
+ inputRef.current?.focus()
105
+ }}
106
+ >
107
+ <input
108
+ ref={inputRef}
109
+ value={search}
110
+ onChange={(e) => setSearch(e.target.value)}
111
+ placeholder='Search fonts…'
112
+ className='placeholder:text-muted-foreground w-full border-b bg-transparent px-2 py-1.5 text-xs outline-none'
113
+ />
114
+ <ScrollArea
115
+ className='relative'
116
+ style={{ height: listHeight }}
117
+ viewportRef={viewportRef}
118
+ >
119
+ <div
120
+ style={{
121
+ height: virtualizer.getTotalSize(),
122
+ position: 'relative',
123
+ }}
124
+ >
125
+ {virtualizer.getVirtualItems().map((vi) => {
126
+ const font = filtered[vi.index]
127
+ const selected =
128
+ font.postScriptName === value || font.familyName === value
129
+ return (
130
+ <button
131
+ key={vi.key}
132
+ type='button'
133
+ className={cn(
134
+ 'hover:bg-accent hover:text-accent-foreground absolute left-0 flex w-full cursor-default items-center gap-1.5 rounded-sm px-2 text-xs select-none',
135
+ selected && 'bg-accent',
136
+ )}
137
+ style={{
138
+ height: ITEM_HEIGHT,
139
+ top: vi.start,
140
+ fontFamily: font.familyName,
141
+ }}
142
+ onClick={() => {
143
+ onChange(font.postScriptName)
144
+ setOpen(false)
145
+ setSearch('')
146
+ }}
147
+ >
148
+ <span className='flex size-3 shrink-0 items-center justify-center'>
149
+ {selected && <CheckIcon className='size-3' />}
150
+ </span>
151
+ <span className='truncate'>{font.familyName}</span>
152
+ </button>
153
+ )
154
+ })}
155
+ </div>
156
+ </ScrollArea>
157
+ {filtered.length === 0 && (
158
+ <div className='text-muted-foreground px-2 py-4 text-center text-xs'>
159
+ No fonts found
160
+ </div>
161
+ )}
162
+ </PopoverContent>
163
+ </Popover>
164
+ )
165
+ }
ui/hooks/useBlobData.ts CHANGED
@@ -23,7 +23,7 @@ export function useBlobData(hash: string | undefined): Uint8Array | undefined {
23
  enabled: !!hash,
24
  placeholderData: keepPreviousData,
25
  })
26
- return data
27
  }
28
 
29
  const blobImageQueryOptions = (hash: string) => ({
 
23
  enabled: !!hash,
24
  placeholderData: keepPreviousData,
25
  })
26
+ return hash ? data : undefined
27
  }
28
 
29
  const blobImageQueryOptions = (hash: string) => ({
ui/hooks/useTextBlocks.ts CHANGED
@@ -44,6 +44,10 @@ const mapTextBlock = (
44
  style: block.style as TextBlock['style'],
45
  fontPrediction: block.fontPrediction as TextBlock['fontPrediction'],
46
  rendered: block.rendered ?? undefined,
 
 
 
 
47
  })
48
 
49
  export type MappedDocument = {
 
44
  style: block.style as TextBlock['style'],
45
  fontPrediction: block.fontPrediction as TextBlock['fontPrediction'],
46
  rendered: block.rendered ?? undefined,
47
+ renderX: block.renderX ?? undefined,
48
+ renderY: block.renderY ?? undefined,
49
+ renderWidth: block.renderWidth ?? undefined,
50
+ renderHeight: block.renderHeight ?? undefined,
51
  })
52
 
53
  export type MappedDocument = {
ui/lib/api/schemas/textBlockDetail.ts CHANGED
@@ -18,6 +18,17 @@ export interface TextBlockDetail {
18
  id: string
19
  /** @nullable */
20
  linePolygons?: number[][][] | null
 
 
 
 
 
 
 
 
 
 
 
21
  /**
22
  * Blob hash for the rendered text block sprite.
23
  * @nullable
 
18
  id: string
19
  /** @nullable */
20
  linePolygons?: number[][][] | null
21
+ /** @nullable */
22
+ renderHeight?: number | null
23
+ /** @nullable */
24
+ renderWidth?: number | null
25
+ /**
26
+ * Actual render area position/size (when bubble expansion is used).
27
+ * @nullable
28
+ */
29
+ renderX?: number | null
30
+ /** @nullable */
31
+ renderY?: number | null
32
  /**
33
  * Blob hash for the rendered text block sprite.
34
  * @nullable
ui/openapi.json CHANGED
@@ -2381,6 +2381,23 @@
2381
  }
2382
  }
2383
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2384
  "rendered": {
2385
  "type": ["string", "null"],
2386
  "description": "Blob hash for the rendered text block sprite."
 
2381
  }
2382
  }
2383
  },
2384
+ "renderHeight": {
2385
+ "type": ["number", "null"],
2386
+ "format": "float"
2387
+ },
2388
+ "renderWidth": {
2389
+ "type": ["number", "null"],
2390
+ "format": "float"
2391
+ },
2392
+ "renderX": {
2393
+ "type": ["number", "null"],
2394
+ "format": "float",
2395
+ "description": "Actual render area position/size (when bubble expansion is used)."
2396
+ },
2397
+ "renderY": {
2398
+ "type": ["number", "null"],
2399
+ "format": "float"
2400
+ },
2401
  "rendered": {
2402
  "type": ["string", "null"],
2403
  "description": "Blob hash for the rendered text block sprite."
ui/types.d.ts CHANGED
@@ -69,6 +69,11 @@ export type TextBlock = {
69
  fontPrediction?: FontPrediction
70
  /** Blob hash for the rendered text block sprite. */
71
  rendered?: string
 
 
 
 
 
72
  }
73
 
74
  export type ToolMode = 'select' | 'block' | 'brush' | 'repairBrush' | 'eraser'
 
69
  fontPrediction?: FontPrediction
70
  /** Blob hash for the rendered text block sprite. */
71
  rendered?: string
72
+ /** Actual render area (when bubble expansion is used). */
73
+ renderX?: number
74
+ renderY?: number
75
+ renderWidth?: number
76
+ renderHeight?: number
77
  }
78
 
79
  export type ToolMode = 'select' | 'block' | 'brush' | 'repairBrush' | 'eraser'