Mayo committed on
Commit
aa8927a
·
unverified ·
1 Parent(s): 6911069

feat: anime-text model

Browse files
README.md CHANGED
@@ -186,6 +186,7 @@ Koharu uses multiple pretrained models, each tuned for a specific part of the pa
186
 
187
  These models find text regions, speech bubbles, and page structure.
188
 
 
189
  - [comic-text-bubble-detector](https://huggingface.co/ogkalu/comic-text-and-bubble-detector) for joint text block and speech bubble detection
190
  - [comic-text-detector](https://huggingface.co/mayocream/comic-text-detector) for text segmentation masks
191
  - [PP-DocLayoutV3](https://huggingface.co/PaddlePaddle/PP-DocLayoutV3_safetensors) for document layout analysis
 
186
 
187
  These models find text regions, speech bubbles, and page structure.
188
 
189
+ - [anime-text-yolo](https://huggingface.co/mayocream/anime-text-yolo) for text block detection
190
  - [comic-text-bubble-detector](https://huggingface.co/ogkalu/comic-text-and-bubble-detector) for joint text block and speech bubble detection
191
  - [comic-text-detector](https://huggingface.co/mayocream/comic-text-detector) for text segmentation masks
192
  - [PP-DocLayoutV3](https://huggingface.co/PaddlePaddle/PP-DocLayoutV3_safetensors) for document layout analysis
koharu-app/src/config.rs CHANGED
@@ -83,7 +83,7 @@ pub struct PipelineConfig {
83
  impl Default for PipelineConfig {
84
  fn default() -> Self {
85
  Self {
86
- detector: "pp-doclayout-v3".to_string(),
87
  font_detector: "yuzumarker-font-detection".to_string(),
88
  segmenter: "comic-text-detector-seg".to_string(),
89
  bubble_segmenter: "speech-bubble-segmentation".to_string(),
 
83
  impl Default for PipelineConfig {
84
  fn default() -> Self {
85
  Self {
86
+ detector: "anime-text".to_string(),
87
  font_detector: "yuzumarker-font-detection".to_string(),
88
  segmenter: "comic-text-detector-seg".to_string(),
89
  bubble_segmenter: "speech-bubble-segmentation".to_string(),
koharu-app/src/pipeline/engines/anime_text.rs ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //! Anime Text YOLO detector. Emits `AddNode` ops for each detected text region.
2
+
3
+ use anyhow::Result;
4
+ use async_trait::async_trait;
5
+ use koharu_core::{Op, TextData};
6
+ use koharu_ml::anime_text::AnimeTextDetector;
7
+
8
+ use crate::pipeline::artifacts::Artifact;
9
+ use crate::pipeline::engine::{Engine, EngineCtx, EngineInfo};
10
+ use crate::pipeline::engines::support::{
11
+ clear_text_nodes_ops, load_source_image, new_text_node, page_node_count,
12
+ sort_manga_reading_order, text_region_to_pair,
13
+ };
14
+
15
+ const DETECTOR_NAME: &str = "anime-text";
16
+
17
+ pub struct Model(AnimeTextDetector);
18
+
19
+ #[async_trait]
20
+ impl Engine for Model {
21
+ async fn run(&self, ctx: EngineCtx<'_>) -> Result<Vec<Op>> {
22
+ let image = load_source_image(ctx.scene, ctx.page, ctx.blobs)?;
23
+ let det = self.0.inference(&image)?;
24
+
25
+ let mut pairs: Vec<([f32; 4], TextData)> = det
26
+ .text_blocks
27
+ .into_iter()
28
+ .map(|r| text_region_to_pair(r, DETECTOR_NAME))
29
+ .collect();
30
+ sort_manga_reading_order(&mut pairs);
31
+
32
+ let mut ops = clear_text_nodes_ops(ctx.scene, ctx.page);
33
+ let removed = ops.len();
34
+ let insertion_start = page_node_count(ctx.scene, ctx.page).saturating_sub(removed);
35
+ ops.reserve(pairs.len());
36
+ for (at, (bbox, text)) in (insertion_start..).zip(pairs) {
37
+ let node = new_text_node(bbox, text);
38
+ ops.push(Op::AddNode {
39
+ page: ctx.page,
40
+ node,
41
+ at,
42
+ });
43
+ }
44
+ Ok(ops)
45
+ }
46
+ }
47
+
48
+ inventory::submit! {
49
+ EngineInfo {
50
+ id: "anime-text",
51
+ name: "Anime Text YOLO (N)",
52
+ needs: &[],
53
+ produces: &[Artifact::TextBoxes],
54
+ load: |runtime, cpu| Box::pin(async move {
55
+ let m = AnimeTextDetector::load(runtime, cpu).await?;
56
+ Ok(Box::new(Model(m)) as Box<dyn Engine>)
57
+ }),
58
+ }
59
+ }
koharu-app/src/pipeline/engines/mod.rs CHANGED
@@ -4,6 +4,7 @@
4
  //! `inventory::submit! { EngineInfo { … } }`. The registry picks them up
5
  //! automatically at link time.
6
 
 
7
  pub mod aot;
8
  pub mod bubble_segmentation;
9
  pub mod comic_text_bubble;
 
4
  //! `inventory::submit! { EngineInfo { … } }`. The registry picks them up
5
  //! automatically at link time.
6
 
7
+ pub mod anime_text;
8
  pub mod aot;
9
  pub mod bubble_segmentation;
10
  pub mod comic_text_bubble;
koharu-app/src/pipeline/mod.rs CHANGED
@@ -350,3 +350,19 @@ pub fn catalog() -> EngineCatalog {
350
  .collect(),
351
  }
352
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  .collect(),
351
  }
352
  }
353
+
354
#[cfg(test)]
mod tests {
    use super::*;

    /// The catalog must list a detector with the anime-text id and display name
    /// whose only produced artifact is `TextBoxes`.
    #[test]
    fn catalog_includes_anime_text_detector() {
        let found = catalog().detectors.iter().any(|engine| {
            let produces_only_text_boxes = engine
                .produces
                .iter()
                .map(String::as_str)
                .eq(["TextBoxes"]);
            engine.id == "anime-text"
                && engine.name == "Anime Text YOLO (N)"
                && produces_only_text_boxes
        });

        assert!(found);
    }
}
koharu-ml/Cargo.toml CHANGED
@@ -102,6 +102,10 @@ path = "bin/manga-text-segmentation-2025.rs"
102
  name = "speech-bubble-segmentation"
103
  path = "bin/speech-bubble-segmentation.rs"
104
 
 
 
 
 
105
  [[bin]]
106
  name = "aot-inpainting"
107
  path = "bin/aot-inpainting.rs"
 
102
  name = "speech-bubble-segmentation"
103
  path = "bin/speech-bubble-segmentation.rs"
104
 
105
+ [[bin]]
106
+ name = "anime-text"
107
+ path = "bin/anime-text.rs"
108
+
109
  [[bin]]
110
  name = "aot-inpainting"
111
  path = "bin/aot-inpainting.rs"
koharu-ml/bin/anime-text.rs ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use anyhow::{Result, anyhow, ensure};
2
+ use clap::{Parser, ValueEnum};
3
+ use imageproc::{drawing::draw_hollow_rect_mut, rect::Rect};
4
+ use koharu_ml::anime_text::{AnimeTextDetector, AnimeTextYoloVariant};
5
+ use koharu_runtime::{ComputePolicy, RuntimeManager, default_app_data_root};
6
+ use tokio::runtime::Builder;
7
+
8
+ #[path = "common.rs"]
9
+ mod common;
10
+
11
+ #[derive(Clone, Copy, Debug, ValueEnum)]
12
+ enum Variant {
13
+ N,
14
+ S,
15
+ M,
16
+ L,
17
+ X,
18
+ }
19
+
20
+ impl From<Variant> for AnimeTextYoloVariant {
21
+ fn from(value: Variant) -> Self {
22
+ match value {
23
+ Variant::N => Self::N,
24
+ Variant::S => Self::S,
25
+ Variant::M => Self::M,
26
+ Variant::L => Self::L,
27
+ Variant::X => Self::X,
28
+ }
29
+ }
30
+ }
31
+
32
+ #[derive(Parser)]
33
+ struct Cli {
34
+ #[arg(short, long, value_name = "FILE")]
35
+ input: String,
36
+
37
+ #[arg(short, long, value_name = "FILE")]
38
+ output: String,
39
+
40
+ #[arg(long, value_name = "FILE")]
41
+ json_output: Option<String>,
42
+
43
+ #[arg(long, value_enum, default_value_t = Variant::N)]
44
+ variant: Variant,
45
+
46
+ #[arg(long, default_value_t = 0.25)]
47
+ confidence_threshold: f32,
48
+
49
+ #[arg(long, default_value_t = 0.45)]
50
+ nms_threshold: f32,
51
+
52
+ #[arg(long, default_value_t = false)]
53
+ cpu: bool,
54
+ }
55
+
56
+ fn main() -> Result<()> {
57
+ common::init_tracing();
58
+
59
+ std::thread::Builder::new()
60
+ .name("anime-text-yolo".to_string())
61
+ .stack_size(64 * 1024 * 1024)
62
+ .spawn(|| {
63
+ let runtime = Builder::new_current_thread().enable_all().build()?;
64
+ runtime.block_on(async_main())
65
+ })?
66
+ .join()
67
+ .map_err(|_| anyhow!("anime-text-yolo thread panicked"))?
68
+ }
69
+
70
+ async fn async_main() -> Result<()> {
71
+ let cli = Cli::parse();
72
+ let variant = AnimeTextYoloVariant::from(cli.variant);
73
+
74
+ let runtime = RuntimeManager::new(
75
+ default_app_data_root(),
76
+ if cli.cpu {
77
+ ComputePolicy::CpuOnly
78
+ } else {
79
+ ComputePolicy::PreferGpu
80
+ },
81
+ )?;
82
+ runtime.prepare().await?;
83
+
84
+ let model = AnimeTextDetector::load_variant(&runtime, variant, cli.cpu).await?;
85
+ let bytes = std::fs::read(&cli.input)?;
86
+ let format = image::guess_format(&bytes)?;
87
+ let image = image::load_from_memory_with_format(&bytes, format)?;
88
+ let detection =
89
+ model.inference_with_thresholds(&image, cli.confidence_threshold, cli.nms_threshold)?;
90
+
91
+ ensure!(
92
+ !detection.regions.is_empty(),
93
+ "No anime text blocks detected in the image."
94
+ );
95
+
96
+ let mut image = image.to_rgba8();
97
+ for region in &detection.regions {
98
+ let width = (region.bbox[2] - region.bbox[0]).max(1.0) as u32;
99
+ let height = (region.bbox[3] - region.bbox[1]).max(1.0) as u32;
100
+ draw_hollow_rect_mut(
101
+ &mut image,
102
+ Rect::at(region.bbox[0] as i32, region.bbox[1] as i32).of_size(width, height),
103
+ image::Rgba([255, 0, 0, 255]),
104
+ );
105
+ }
106
+
107
+ image::DynamicImage::ImageRgba8(image).save(&cli.output)?;
108
+ if let Some(path) = &cli.json_output {
109
+ std::fs::write(path, serde_json::to_vec_pretty(&detection)?)?;
110
+ }
111
+ Ok(())
112
+ }
koharu-ml/src/anime_text/mod.rs ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mod model;
2
+
3
+ use std::{path::Path, path::PathBuf, time::Instant};
4
+
5
+ use anyhow::{Context, Result, bail};
6
+ use candle_core::{DType, Device, IndexOp, Tensor};
7
+ use candle_transformers::object_detection::{Bbox, non_maximum_suppression};
8
+ use image::{
9
+ DynamicImage, Rgb, RgbImage,
10
+ imageops::{self, FilterType},
11
+ };
12
+ use koharu_runtime::RuntimeManager;
13
+ use serde::{Deserialize, Serialize};
14
+ use tracing::instrument;
15
+
16
+ use crate::{device, loading, types::TextRegion};
17
+
18
+ use self::model::{Yolo12, Yolo12Scale};
19
+
20
+ pub const HF_REPO: &str = "mayocream/anime-text-yolo";
21
+ const INPUT_SIZE: u32 = 640;
22
+ const NUM_CLASSES: usize = 1;
23
+ const DEFAULT_VARIANT: AnimeTextYoloVariant = AnimeTextYoloVariant::N;
24
+ const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.25;
25
+ const DEFAULT_NMS_THRESHOLD: f32 = 0.45;
26
+ const LETTERBOX_COLOR: u8 = 114;
27
+ const DETECTOR_NAME: &str = "anime-text-yolo";
28
+ const CLASS_NAMES: [&str; NUM_CLASSES] = ["text_block"];
29
+
30
+ koharu_runtime::declare_hf_model_package!(
31
+ id: "model:anime-text-yolo:yolo12n",
32
+ repo: HF_REPO,
33
+ file: "yolo12n_animetext.safetensors",
34
+ bootstrap: false,
35
+ order: 118,
36
+ );
37
+ koharu_runtime::declare_hf_model_package!(
38
+ id: "model:anime-text-yolo:yolo12s",
39
+ repo: HF_REPO,
40
+ file: "yolo12s_animetext.safetensors",
41
+ bootstrap: false,
42
+ order: 119,
43
+ );
44
+ koharu_runtime::declare_hf_model_package!(
45
+ id: "model:anime-text-yolo:yolo12m",
46
+ repo: HF_REPO,
47
+ file: "yolo12m_animetext.safetensors",
48
+ bootstrap: false,
49
+ order: 120,
50
+ );
51
+ koharu_runtime::declare_hf_model_package!(
52
+ id: "model:anime-text-yolo:yolo12l",
53
+ repo: HF_REPO,
54
+ file: "yolo12l_animetext.safetensors",
55
+ bootstrap: false,
56
+ order: 121,
57
+ );
58
+ koharu_runtime::declare_hf_model_package!(
59
+ id: "model:anime-text-yolo:yolo12x",
60
+ repo: HF_REPO,
61
+ file: "yolo12x_animetext.safetensors",
62
+ bootstrap: false,
63
+ order: 122,
64
+ );
65
+
66
+ #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
67
+ #[serde(rename_all = "lowercase")]
68
+ pub enum AnimeTextYoloVariant {
69
+ N,
70
+ S,
71
+ M,
72
+ L,
73
+ X,
74
+ }
75
+
76
+ impl AnimeTextYoloVariant {
77
+ pub fn filename(self) -> &'static str {
78
+ match self {
79
+ Self::N => "yolo12n_animetext.safetensors",
80
+ Self::S => "yolo12s_animetext.safetensors",
81
+ Self::M => "yolo12m_animetext.safetensors",
82
+ Self::L => "yolo12l_animetext.safetensors",
83
+ Self::X => "yolo12x_animetext.safetensors",
84
+ }
85
+ }
86
+
87
+ pub fn as_str(self) -> &'static str {
88
+ match self {
89
+ Self::N => "n",
90
+ Self::S => "s",
91
+ Self::M => "m",
92
+ Self::L => "l",
93
+ Self::X => "x",
94
+ }
95
+ }
96
+
97
+ fn scale(self) -> Yolo12Scale {
98
+ match self {
99
+ Self::N => Yolo12Scale::N,
100
+ Self::S => Yolo12Scale::S,
101
+ Self::M => Yolo12Scale::M,
102
+ Self::L => Yolo12Scale::L,
103
+ Self::X => Yolo12Scale::X,
104
+ }
105
+ }
106
+ }
107
+
108
+ impl std::fmt::Display for AnimeTextYoloVariant {
109
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110
+ f.write_str(self.as_str())
111
+ }
112
+ }
113
+
114
+ #[derive(Debug)]
115
+ pub struct AnimeTextDetector {
116
+ model: Yolo12,
117
+ variant: AnimeTextYoloVariant,
118
+ device: Device,
119
+ dtype: DType,
120
+ }
121
+
122
+ #[derive(Debug, Clone)]
123
+ struct PreparedInput {
124
+ pixel_values: Tensor,
125
+ original_width: u32,
126
+ original_height: u32,
127
+ pad_x: u32,
128
+ pad_y: u32,
129
+ scale: f32,
130
+ }
131
+
132
+ #[derive(Debug, Clone, Serialize)]
133
+ #[serde(rename_all = "camelCase")]
134
+ pub struct AnimeTextDetection {
135
+ pub image_width: u32,
136
+ pub image_height: u32,
137
+ pub variant: AnimeTextYoloVariant,
138
+ pub regions: Vec<AnimeTextRegion>,
139
+ pub text_blocks: Vec<TextRegion>,
140
+ }
141
+
142
+ #[derive(Debug, Clone, Serialize)]
143
+ #[serde(rename_all = "camelCase")]
144
+ pub struct AnimeTextRegion {
145
+ pub label_id: usize,
146
+ pub label: String,
147
+ pub score: f32,
148
+ pub bbox: [f32; 4],
149
+ }
150
+
151
+ impl AnimeTextDetector {
152
+ pub async fn load(runtime: &RuntimeManager, cpu: bool) -> Result<Self> {
153
+ Self::load_variant(runtime, DEFAULT_VARIANT, cpu).await
154
+ }
155
+
156
+ pub async fn load_variant(
157
+ runtime: &RuntimeManager,
158
+ variant: AnimeTextYoloVariant,
159
+ cpu: bool,
160
+ ) -> Result<Self> {
161
+ let weights_path = resolve_model_path(runtime, variant).await?;
162
+ Self::load_from_path(weights_path, variant, cpu)
163
+ }
164
+
165
+ pub fn load_from_path(
166
+ weights_path: impl AsRef<Path>,
167
+ variant: AnimeTextYoloVariant,
168
+ cpu: bool,
169
+ ) -> Result<Self> {
170
+ let device = device(cpu)?;
171
+ let dtype = loading::model_dtype(&device);
172
+ let model = loading::load_mmaped_safetensors_path_with_dtype(
173
+ weights_path.as_ref(),
174
+ &device,
175
+ dtype,
176
+ |vb| Yolo12::load(vb, variant.scale(), NUM_CLASSES),
177
+ )
178
+ .with_context(|| {
179
+ format!(
180
+ "failed to load anime text YOLO {} weights from {}",
181
+ variant,
182
+ weights_path.as_ref().display()
183
+ )
184
+ })?;
185
+
186
+ Ok(Self {
187
+ model,
188
+ variant,
189
+ device,
190
+ dtype,
191
+ })
192
+ }
193
+
194
+ pub fn variant(&self) -> AnimeTextYoloVariant {
195
+ self.variant
196
+ }
197
+
198
+ #[instrument(level = "debug", skip_all)]
199
+ pub fn inference(&self, image: &DynamicImage) -> Result<AnimeTextDetection> {
200
+ self.inference_with_thresholds(image, DEFAULT_CONFIDENCE_THRESHOLD, DEFAULT_NMS_THRESHOLD)
201
+ }
202
+
203
+ #[instrument(level = "debug", skip_all)]
204
+ pub fn inference_with_thresholds(
205
+ &self,
206
+ image: &DynamicImage,
207
+ confidence_threshold: f32,
208
+ nms_threshold: f32,
209
+ ) -> Result<AnimeTextDetection> {
210
+ let started = Instant::now();
211
+ let prepared = self.preprocess(image)?;
212
+ let outputs = self.model.forward(&prepared.pixel_values)?;
213
+ let regions = postprocess(&outputs, &prepared, confidence_threshold, nms_threshold)?;
214
+ let text_blocks = regions_to_text_blocks(&regions);
215
+
216
+ tracing::info!(
217
+ width = image.width(),
218
+ height = image.height(),
219
+ variant = %self.variant,
220
+ detections = regions.len(),
221
+ total_ms = started.elapsed().as_millis(),
222
+ "anime text YOLO timings"
223
+ );
224
+
225
+ Ok(AnimeTextDetection {
226
+ image_width: prepared.original_width,
227
+ image_height: prepared.original_height,
228
+ variant: self.variant,
229
+ regions,
230
+ text_blocks,
231
+ })
232
+ }
233
+
234
+ fn preprocess(&self, image: &DynamicImage) -> Result<PreparedInput> {
235
+ let rgb = image.to_rgb8();
236
+ let (original_width, original_height) = rgb.dimensions();
237
+ let scale = f32::min(
238
+ INPUT_SIZE as f32 / original_width.max(1) as f32,
239
+ INPUT_SIZE as f32 / original_height.max(1) as f32,
240
+ );
241
+ let resized_width = ((original_width as f32 * scale).round() as u32).clamp(1, INPUT_SIZE);
242
+ let resized_height = ((original_height as f32 * scale).round() as u32).clamp(1, INPUT_SIZE);
243
+ let pad_x = (INPUT_SIZE - resized_width) / 2;
244
+ let pad_y = (INPUT_SIZE - resized_height) / 2;
245
+
246
+ let resized = if resized_width == original_width && resized_height == original_height {
247
+ rgb
248
+ } else {
249
+ imageops::resize(&rgb, resized_width, resized_height, FilterType::Triangle)
250
+ };
251
+
252
+ let mut letterboxed =
253
+ RgbImage::from_pixel(INPUT_SIZE, INPUT_SIZE, Rgb([LETTERBOX_COLOR; 3]));
254
+ imageops::overlay(
255
+ &mut letterboxed,
256
+ &resized,
257
+ i64::from(pad_x),
258
+ i64::from(pad_y),
259
+ );
260
+
261
+ let pixel_values = Tensor::from_vec(
262
+ letterboxed.into_raw(),
263
+ (1, INPUT_SIZE as usize, INPUT_SIZE as usize, 3),
264
+ &self.device,
265
+ )?
266
+ .permute((0, 3, 1, 2))?
267
+ .to_dtype(self.dtype)?;
268
+ let pixel_values = (pixel_values * (1.0 / 255.0))?;
269
+
270
+ Ok(PreparedInput {
271
+ pixel_values,
272
+ original_width,
273
+ original_height,
274
+ pad_x,
275
+ pad_y,
276
+ scale,
277
+ })
278
+ }
279
+ }
280
+
281
+ pub async fn prefetch(runtime: &RuntimeManager) -> Result<()> {
282
+ prefetch_variant(runtime, DEFAULT_VARIANT).await
283
+ }
284
+
285
+ pub async fn prefetch_variant(
286
+ runtime: &RuntimeManager,
287
+ variant: AnimeTextYoloVariant,
288
+ ) -> Result<()> {
289
+ let _ = resolve_model_path(runtime, variant).await?;
290
+ Ok(())
291
+ }
292
+
293
+ async fn resolve_model_path(
294
+ runtime: &RuntimeManager,
295
+ variant: AnimeTextYoloVariant,
296
+ ) -> Result<PathBuf> {
297
+ runtime
298
+ .downloads()
299
+ .huggingface_model(HF_REPO, variant.filename())
300
+ .await
301
+ .with_context(|| format!("failed to download {} from {}", variant.filename(), HF_REPO))
302
+ }
303
+
304
+ fn postprocess(
305
+ outputs: &Tensor,
306
+ prepared: &PreparedInput,
307
+ confidence_threshold: f32,
308
+ nms_threshold: f32,
309
+ ) -> Result<Vec<AnimeTextRegion>> {
310
+ let pred = outputs
311
+ .to_dtype(DType::F32)?
312
+ .to_device(&Device::Cpu)?
313
+ .i(0)?;
314
+ let (channels, anchors) = pred.dims2()?;
315
+ let expected_channels = 4 + NUM_CLASSES;
316
+ if channels != expected_channels {
317
+ bail!(
318
+ "unexpected anime text YOLO prediction channels {channels}, expected {expected_channels}"
319
+ );
320
+ }
321
+
322
+ let mut grouped: Vec<Vec<Bbox<usize>>> = (0..NUM_CLASSES).map(|_| Vec::new()).collect();
323
+ for anchor_idx in 0..anchors {
324
+ let values = pred.i((.., anchor_idx))?.to_vec1::<f32>()?;
325
+ let class_scores = &values[4..4 + NUM_CLASSES];
326
+ let Some((label_id, &score)) = class_scores
327
+ .iter()
328
+ .enumerate()
329
+ .max_by(|(_, a), (_, b)| a.total_cmp(b))
330
+ else {
331
+ continue;
332
+ };
333
+ if score < confidence_threshold {
334
+ continue;
335
+ }
336
+
337
+ let bbox = map_bbox_to_original(
338
+ [
339
+ values[0] - values[2] * 0.5,
340
+ values[1] - values[3] * 0.5,
341
+ values[0] + values[2] * 0.5,
342
+ values[1] + values[3] * 0.5,
343
+ ],
344
+ prepared,
345
+ );
346
+ if bbox[2] <= bbox[0] || bbox[3] <= bbox[1] {
347
+ continue;
348
+ }
349
+
350
+ grouped[label_id].push(Bbox {
351
+ xmin: bbox[0],
352
+ ymin: bbox[1],
353
+ xmax: bbox[2],
354
+ ymax: bbox[3],
355
+ confidence: score,
356
+ data: label_id,
357
+ });
358
+ }
359
+
360
+ non_maximum_suppression(&mut grouped, nms_threshold);
361
+
362
+ let mut regions = Vec::new();
363
+ for (label_id, bboxes) in grouped.into_iter().enumerate() {
364
+ let label = CLASS_NAMES
365
+ .get(label_id)
366
+ .copied()
367
+ .unwrap_or("text_block")
368
+ .to_string();
369
+ for bbox in bboxes {
370
+ regions.push(AnimeTextRegion {
371
+ label_id,
372
+ label: label.clone(),
373
+ score: bbox.confidence,
374
+ bbox: [bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax],
375
+ });
376
+ }
377
+ }
378
+ regions.sort_by(|a, b| b.score.total_cmp(&a.score));
379
+ Ok(regions)
380
+ }
381
+
382
+ fn map_bbox_to_original(bbox: [f32; 4], prepared: &PreparedInput) -> [f32; 4] {
383
+ let width = prepared.original_width as f32;
384
+ let height = prepared.original_height as f32;
385
+ let pad_x = prepared.pad_x as f32;
386
+ let pad_y = prepared.pad_y as f32;
387
+ [
388
+ ((bbox[0] - pad_x) / prepared.scale).clamp(0.0, width),
389
+ ((bbox[1] - pad_y) / prepared.scale).clamp(0.0, height),
390
+ ((bbox[2] - pad_x) / prepared.scale).clamp(0.0, width),
391
+ ((bbox[3] - pad_y) / prepared.scale).clamp(0.0, height),
392
+ ]
393
+ }
394
+
395
+ fn regions_to_text_blocks(regions: &[AnimeTextRegion]) -> Vec<TextRegion> {
396
+ regions
397
+ .iter()
398
+ .filter_map(|region| {
399
+ let width = (region.bbox[2] - region.bbox[0]).max(0.0);
400
+ let height = (region.bbox[3] - region.bbox[1]).max(0.0);
401
+ if width <= 1.0 || height <= 1.0 {
402
+ return None;
403
+ }
404
+ Some(TextRegion {
405
+ x: region.bbox[0],
406
+ y: region.bbox[1],
407
+ width,
408
+ height,
409
+ confidence: region.score,
410
+ detector: Some(DETECTOR_NAME.to_string()),
411
+ ..Default::default()
412
+ })
413
+ })
414
+ .collect()
415
+ }
416
+
417
#[cfg(test)]
mod tests {
    use super::{PreparedInput, map_bbox_to_original};
    use candle_core::{DType, Device, Tensor};

    /// A 1000x500 image letterboxed with scale 0.64 and a vertical offset of
    /// 160: the offset is removed and the scale undone on every coordinate.
    #[test]
    fn map_bbox_to_original_removes_letterbox_padding() {
        let pixel_values =
            Tensor::zeros((1, 3, 640, 640), DType::F32, &Device::Cpu).expect("tensor");
        let prepared = PreparedInput {
            pixel_values,
            original_width: 1000,
            original_height: 500,
            pad_x: 0,
            pad_y: 160,
            scale: 0.64,
        };

        let mapped = map_bbox_to_original([100.0, 200.0, 540.0, 440.0], &prepared);
        let expected: [f32; 4] = [156.25, 62.5, 843.75, 437.5];
        for (actual, wanted) in mapped.iter().zip(expected) {
            assert!((actual - wanted).abs() < 1e-3);
        }
    }
}
koharu-ml/src/anime_text/model.rs ADDED
@@ -0,0 +1,982 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use candle_core::{D, IndexOp, Result, Tensor};
2
+ use candle_nn::{BatchNorm, Conv2d, Conv2dConfig, Module, ModuleT, VarBuilder, batch_norm};
3
+
4
+ use crate::ops::{conv2d, conv2d_no_bias};
5
+
6
+ const BN_EPS: f64 = 1e-3;
7
+ const REG_MAX: usize = 16;
8
+
9
/// Size of the YOLO12 network; selects the depth/width multipliers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Yolo12Scale {
    N,
    S,
    M,
    L,
    X,
}

/// Scaling factors applied to a base layer configuration.
#[derive(Debug, Clone, Copy)]
struct Multiples {
    /// Multiplier on block repeat counts.
    depth: f64,
    /// Multiplier on channel counts.
    width: f64,
    /// Cap applied to a base channel count before `width` is applied.
    max_channels: usize,
}

impl Yolo12Scale {
    /// Depth, width and channel cap for this scale.
    fn multiples(self) -> Multiples {
        let (depth, width, max_channels) = match self {
            Self::N => (0.50, 0.25, 1024),
            Self::S => (0.50, 0.50, 1024),
            Self::M => (0.50, 1.00, 512),
            Self::L => (1.00, 1.00, 512),
            Self::X => (1.00, 1.50, 512),
        };
        Multiples {
            depth,
            width,
            max_channels,
        }
    }

    /// `true` for the M, L and X scales.
    fn uses_large_c3k(self) -> bool {
        matches!(self, Self::M | Self::L | Self::X)
    }

    /// `true` for the L and X scales.
    fn uses_a2_residual(self) -> bool {
        matches!(self, Self::L | Self::X)
    }
}

impl Multiples {
    /// Channel count for a layer whose base configuration asks for `base`
    /// channels: capped at `max_channels`, scaled by `width`, then rounded up
    /// to a multiple of 8.
    fn channels(&self, base: usize) -> usize {
        make_divisible((base.min(self.max_channels) as f64) * self.width, 8)
    }

    /// Repeat count for a block whose base configuration asks for `base`
    /// repeats. Counts of 0 or 1 are returned unchanged; larger counts are
    /// scaled by `depth`, rounded, and never drop below 1.
    fn repeats(&self, base: usize) -> usize {
        if base <= 1 {
            return base;
        }
        // Ties round to even, matching Python's `round`, which the reference
        // Ultralytics depth-scaling rule uses. `f64::round` would round ties
        // away from zero, e.g. 5 * 0.5 would give 3 instead of 2.
        let scaled = (base as f64 * self.depth).round_ties_even();
        (scaled as usize).max(1)
    }
}

/// Rounds `value` up to the nearest multiple of `divisor`.
fn make_divisible(value: f64, divisor: usize) -> usize {
    ((value / divisor as f64).ceil() as usize) * divisor
}
82
+
83
+ #[derive(Debug)]
84
+ struct Upsample {
85
+ scale_factor: usize,
86
+ }
87
+
88
+ impl Upsample {
89
+ fn new(scale_factor: usize) -> Self {
90
+ Self { scale_factor }
91
+ }
92
+ }
93
+
94
+ impl Module for Upsample {
95
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
96
+ let (_, _, h, w) = xs.dims4()?;
97
+ xs.upsample_nearest2d(self.scale_factor * h, self.scale_factor * w)
98
+ }
99
+ }
100
+
101
+ #[derive(Debug)]
102
+ struct ConvBlock {
103
+ conv: Conv2d,
104
+ bn: BatchNorm,
105
+ activation: bool,
106
+ }
107
+
108
+ impl ConvBlock {
109
+ #[allow(clippy::too_many_arguments)]
110
+ fn load(
111
+ vb: VarBuilder,
112
+ in_channels: usize,
113
+ out_channels: usize,
114
+ kernel_size: usize,
115
+ stride: usize,
116
+ padding: Option<usize>,
117
+ groups: usize,
118
+ activation: bool,
119
+ ) -> Result<Self> {
120
+ let cfg = Conv2dConfig {
121
+ padding: padding.unwrap_or(kernel_size / 2),
122
+ stride,
123
+ groups,
124
+ dilation: 1,
125
+ cudnn_fwd_algo: None,
126
+ };
127
+ Ok(Self {
128
+ conv: conv2d_no_bias(in_channels, out_channels, kernel_size, cfg, vb.pp("conv"))?,
129
+ bn: batch_norm(out_channels, BN_EPS, vb.pp("bn"))?,
130
+ activation,
131
+ })
132
+ }
133
+ }
134
+
135
+ impl Module for ConvBlock {
136
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
137
+ let xs = self.conv.forward(xs)?;
138
+ let xs = self.bn.forward_t(&xs, false)?;
139
+ if self.activation {
140
+ candle_nn::ops::silu(&xs)
141
+ } else {
142
+ Ok(xs)
143
+ }
144
+ }
145
+ }
146
+
147
+ #[derive(Debug)]
148
+ struct Bottleneck {
149
+ cv1: ConvBlock,
150
+ cv2: ConvBlock,
151
+ residual: bool,
152
+ }
153
+
154
+ impl Bottleneck {
155
+ fn load(
156
+ vb: VarBuilder,
157
+ in_channels: usize,
158
+ out_channels: usize,
159
+ shortcut: bool,
160
+ groups: usize,
161
+ kernel_size: usize,
162
+ expansion: f64,
163
+ ) -> Result<Self> {
164
+ let hidden = (out_channels as f64 * expansion) as usize;
165
+ Ok(Self {
166
+ cv1: ConvBlock::load(
167
+ vb.pp("cv1"),
168
+ in_channels,
169
+ hidden,
170
+ kernel_size,
171
+ 1,
172
+ None,
173
+ 1,
174
+ true,
175
+ )?,
176
+ cv2: ConvBlock::load(
177
+ vb.pp("cv2"),
178
+ hidden,
179
+ out_channels,
180
+ kernel_size,
181
+ 1,
182
+ None,
183
+ groups,
184
+ true,
185
+ )?,
186
+ residual: shortcut && in_channels == out_channels,
187
+ })
188
+ }
189
+ }
190
+
191
+ impl Module for Bottleneck {
192
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
193
+ let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
194
+ if self.residual { xs + ys } else { Ok(ys) }
195
+ }
196
+ }
197
+
198
+ #[derive(Debug)]
199
+ struct C3k {
200
+ cv1: ConvBlock,
201
+ cv2: ConvBlock,
202
+ cv3: ConvBlock,
203
+ blocks: Vec<Bottleneck>,
204
+ }
205
+
206
+ #[derive(Debug, Clone, Copy)]
207
+ struct C3kOptions {
208
+ shortcut: bool,
209
+ groups: usize,
210
+ expansion: f64,
211
+ kernel_size: usize,
212
+ }
213
+
214
+ impl C3k {
215
+ fn load(
216
+ vb: VarBuilder,
217
+ in_channels: usize,
218
+ out_channels: usize,
219
+ repeats: usize,
220
+ options: C3kOptions,
221
+ ) -> Result<Self> {
222
+ let hidden = (out_channels as f64 * options.expansion) as usize;
223
+ let mut blocks = Vec::with_capacity(repeats);
224
+ for index in 0..repeats {
225
+ blocks.push(Bottleneck::load(
226
+ vb.pp(format!("m.{index}")),
227
+ hidden,
228
+ hidden,
229
+ options.shortcut,
230
+ options.groups,
231
+ options.kernel_size,
232
+ 1.0,
233
+ )?);
234
+ }
235
+ Ok(Self {
236
+ cv1: ConvBlock::load(vb.pp("cv1"), in_channels, hidden, 1, 1, None, 1, true)?,
237
+ cv2: ConvBlock::load(vb.pp("cv2"), in_channels, hidden, 1, 1, None, 1, true)?,
238
+ cv3: ConvBlock::load(vb.pp("cv3"), hidden * 2, out_channels, 1, 1, None, 1, true)?,
239
+ blocks,
240
+ })
241
+ }
242
+ }
243
+
244
+ impl Module for C3k {
245
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
246
+ let mut y1 = self.cv1.forward(xs)?;
247
+ for block in &self.blocks {
248
+ y1 = block.forward(&y1)?;
249
+ }
250
+ let y2 = self.cv2.forward(xs)?;
251
+ self.cv3.forward(&Tensor::cat(&[&y1, &y2], 1)?)
252
+ }
253
+ }
254
+
255
+ #[derive(Debug)]
256
+ enum C3k2Block {
257
+ Bottleneck(Box<Bottleneck>),
258
+ C3k(Box<C3k>),
259
+ }
260
+
261
+ impl Module for C3k2Block {
262
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
263
+ match self {
264
+ Self::Bottleneck(block) => block.forward(xs),
265
+ Self::C3k(block) => block.forward(xs),
266
+ }
267
+ }
268
+ }
269
+
270
+ #[derive(Debug)]
271
+ struct C3k2 {
272
+ cv1: ConvBlock,
273
+ cv2: ConvBlock,
274
+ blocks: Vec<C3k2Block>,
275
+ }
276
+
277
+ #[derive(Debug, Clone, Copy)]
278
+ struct C3k2Options {
279
+ use_c3k: bool,
280
+ expansion: f64,
281
+ groups: usize,
282
+ shortcut: bool,
283
+ }
284
+
285
+ impl C3k2 {
286
+ fn load(
287
+ vb: VarBuilder,
288
+ in_channels: usize,
289
+ out_channels: usize,
290
+ repeats: usize,
291
+ options: C3k2Options,
292
+ ) -> Result<Self> {
293
+ let hidden = (out_channels as f64 * options.expansion) as usize;
294
+ let mut blocks = Vec::with_capacity(repeats);
295
+ for index in 0..repeats {
296
+ let vb = vb.pp(format!("m.{index}"));
297
+ let block = if options.use_c3k {
298
+ C3k2Block::C3k(Box::new(C3k::load(
299
+ vb,
300
+ hidden,
301
+ hidden,
302
+ 2,
303
+ C3kOptions {
304
+ shortcut: options.shortcut,
305
+ groups: options.groups,
306
+ expansion: 0.5,
307
+ kernel_size: 3,
308
+ },
309
+ )?))
310
+ } else {
311
+ C3k2Block::Bottleneck(Box::new(Bottleneck::load(
312
+ vb,
313
+ hidden,
314
+ hidden,
315
+ options.shortcut,
316
+ options.groups,
317
+ 3,
318
+ 0.5,
319
+ )?))
320
+ };
321
+ blocks.push(block);
322
+ }
323
+ Ok(Self {
324
+ cv1: ConvBlock::load(vb.pp("cv1"), in_channels, hidden * 2, 1, 1, None, 1, true)?,
325
+ cv2: ConvBlock::load(
326
+ vb.pp("cv2"),
327
+ (2 + repeats) * hidden,
328
+ out_channels,
329
+ 1,
330
+ 1,
331
+ None,
332
+ 1,
333
+ true,
334
+ )?,
335
+ blocks,
336
+ })
337
+ }
338
+ }
339
+
340
+ impl Module for C3k2 {
341
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
342
+ let mut ys = self.cv1.forward(xs)?.chunk(2, 1)?;
343
+ for block in &self.blocks {
344
+ ys.push(block.forward(ys.last().expect("c3k2 chunk"))?);
345
+ }
346
+ let refs = ys.iter().collect::<Vec<_>>();
347
+ self.cv2.forward(&Tensor::cat(&refs, 1)?)
348
+ }
349
+ }
350
+
351
+ #[derive(Debug)]
352
+ struct AreaAttention {
353
+ area: usize,
354
+ num_heads: usize,
355
+ head_dim: usize,
356
+ qkv: ConvBlock,
357
+ proj: ConvBlock,
358
+ pe: ConvBlock,
359
+ }
360
+
361
+ impl AreaAttention {
362
+ fn load(vb: VarBuilder, dim: usize, num_heads: usize, area: usize) -> Result<Self> {
363
+ let head_dim = dim / num_heads;
364
+ let all_head_dim = head_dim * num_heads;
365
+ Ok(Self {
366
+ area,
367
+ num_heads,
368
+ head_dim,
369
+ qkv: ConvBlock::load(vb.pp("qkv"), dim, all_head_dim * 3, 1, 1, None, 1, false)?,
370
+ proj: ConvBlock::load(vb.pp("proj"), all_head_dim, dim, 1, 1, None, 1, false)?,
371
+ pe: ConvBlock::load(vb.pp("pe"), all_head_dim, dim, 7, 1, Some(3), dim, false)?,
372
+ })
373
+ }
374
+ }
375
+
376
+ impl Module for AreaAttention {
377
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
378
+ let (batch, channels, height, width) = xs.dims4()?;
379
+ let num_tokens = height * width;
380
+ let qkv = self
381
+ .qkv
382
+ .forward(xs)?
383
+ .flatten_from(2)?
384
+ .transpose(1, 2)?
385
+ .contiguous()?;
386
+ let qkv = if self.area > 1 {
387
+ qkv.reshape((batch * self.area, num_tokens / self.area, channels * 3))?
388
+ } else {
389
+ qkv
390
+ };
391
+ let (area_batch, area_tokens, _) = qkv.dims3()?;
392
+ let qkv = qkv
393
+ .reshape((area_batch, area_tokens, self.num_heads, self.head_dim * 3))?
394
+ .permute((0, 2, 3, 1))?
395
+ .contiguous()?;
396
+
397
+ let q = qkv.narrow(2, 0, self.head_dim)?;
398
+ let k = qkv.narrow(2, self.head_dim, self.head_dim)?;
399
+ let v = qkv.narrow(2, self.head_dim * 2, self.head_dim)?;
400
+ let attn = (q.transpose(2, 3)?.matmul(&k)? * (self.head_dim as f64).powf(-0.5))?;
401
+ let attn = candle_nn::ops::softmax(&attn, D::Minus1)?;
402
+ let ys = v.matmul(&attn.transpose(2, 3)?)?;
403
+ let ys = ys.permute((0, 3, 1, 2))?.contiguous()?;
404
+ let v = v.permute((0, 3, 1, 2))?.contiguous()?;
405
+
406
+ let (ys, v) = if self.area > 1 {
407
+ (
408
+ ys.reshape((batch, num_tokens, channels))?,
409
+ v.reshape((batch, num_tokens, channels))?,
410
+ )
411
+ } else {
412
+ (ys, v)
413
+ };
414
+
415
+ let ys = ys
416
+ .reshape((batch, height, width, channels))?
417
+ .permute((0, 3, 1, 2))?
418
+ .contiguous()?;
419
+ let v = v
420
+ .reshape((batch, height, width, channels))?
421
+ .permute((0, 3, 1, 2))?
422
+ .contiguous()?;
423
+ self.proj.forward(&(ys + self.pe.forward(&v)?)?)
424
+ }
425
+ }
426
+
427
+ #[derive(Debug)]
428
+ struct AreaBlock {
429
+ attn: AreaAttention,
430
+ mlp0: ConvBlock,
431
+ mlp1: ConvBlock,
432
+ }
433
+
434
+ impl AreaBlock {
435
+ fn load(
436
+ vb: VarBuilder,
437
+ dim: usize,
438
+ num_heads: usize,
439
+ mlp_ratio: f64,
440
+ area: usize,
441
+ ) -> Result<Self> {
442
+ let mlp_hidden = (dim as f64 * mlp_ratio) as usize;
443
+ Ok(Self {
444
+ attn: AreaAttention::load(vb.pp("attn"), dim, num_heads, area)?,
445
+ mlp0: ConvBlock::load(vb.pp("mlp.0"), dim, mlp_hidden, 1, 1, None, 1, true)?,
446
+ mlp1: ConvBlock::load(vb.pp("mlp.1"), mlp_hidden, dim, 1, 1, None, 1, false)?,
447
+ })
448
+ }
449
+ }
450
+
451
+ impl Module for AreaBlock {
452
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
453
+ let xs = (xs + self.attn.forward(xs)?)?;
454
+ let mlp = self.mlp1.forward(&self.mlp0.forward(&xs)?)?;
455
+ xs + mlp
456
+ }
457
+ }
458
+
459
+ #[derive(Debug)]
460
+ enum A2C2fBlock {
461
+ Attention(Vec<AreaBlock>),
462
+ C3k(Box<C3k>),
463
+ }
464
+
465
+ impl Module for A2C2fBlock {
466
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
467
+ match self {
468
+ Self::Attention(blocks) => {
469
+ let mut ys = xs.clone();
470
+ for block in blocks {
471
+ ys = block.forward(&ys)?;
472
+ }
473
+ Ok(ys)
474
+ }
475
+ Self::C3k(block) => block.forward(xs),
476
+ }
477
+ }
478
+ }
479
+
480
+ #[derive(Debug)]
481
+ struct A2C2f {
482
+ cv1: ConvBlock,
483
+ cv2: ConvBlock,
484
+ gamma: Option<Tensor>,
485
+ blocks: Vec<A2C2fBlock>,
486
+ }
487
+
488
+ #[allow(clippy::too_many_arguments)]
489
+ impl A2C2f {
490
+ fn load(
491
+ vb: VarBuilder,
492
+ in_channels: usize,
493
+ out_channels: usize,
494
+ repeats: usize,
495
+ attention: bool,
496
+ area: usize,
497
+ residual: bool,
498
+ mlp_ratio: f64,
499
+ expansion: f64,
500
+ groups: usize,
501
+ shortcut: bool,
502
+ ) -> Result<Self> {
503
+ let hidden = (out_channels as f64 * expansion) as usize;
504
+ let gamma = if attention && residual {
505
+ Some(vb.get(out_channels, "gamma")?)
506
+ } else {
507
+ None
508
+ };
509
+ let mut blocks = Vec::with_capacity(repeats);
510
+ for index in 0..repeats {
511
+ let block_vb = vb.pp(format!("m.{index}"));
512
+ let block = if attention {
513
+ let mut area_blocks = Vec::with_capacity(2);
514
+ for block_index in 0..2 {
515
+ area_blocks.push(AreaBlock::load(
516
+ block_vb.pp(block_index),
517
+ hidden,
518
+ hidden / 32,
519
+ mlp_ratio,
520
+ area,
521
+ )?);
522
+ }
523
+ A2C2fBlock::Attention(area_blocks)
524
+ } else {
525
+ A2C2fBlock::C3k(Box::new(C3k::load(
526
+ block_vb,
527
+ hidden,
528
+ hidden,
529
+ 2,
530
+ C3kOptions {
531
+ shortcut,
532
+ groups,
533
+ expansion: 0.5,
534
+ kernel_size: 3,
535
+ },
536
+ )?))
537
+ };
538
+ blocks.push(block);
539
+ }
540
+ Ok(Self {
541
+ cv1: ConvBlock::load(vb.pp("cv1"), in_channels, hidden, 1, 1, None, 1, true)?,
542
+ cv2: ConvBlock::load(
543
+ vb.pp("cv2"),
544
+ (1 + repeats) * hidden,
545
+ out_channels,
546
+ 1,
547
+ 1,
548
+ None,
549
+ 1,
550
+ true,
551
+ )?,
552
+ gamma,
553
+ blocks,
554
+ })
555
+ }
556
+ }
557
+
558
+ impl Module for A2C2f {
559
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
560
+ let mut ys = vec![self.cv1.forward(xs)?];
561
+ for block in &self.blocks {
562
+ ys.push(block.forward(ys.last().expect("a2c2f output"))?);
563
+ }
564
+ let refs = ys.iter().collect::<Vec<_>>();
565
+ let ys = self.cv2.forward(&Tensor::cat(&refs, 1)?)?;
566
+ match &self.gamma {
567
+ Some(gamma) => {
568
+ xs + ys.broadcast_mul(&gamma.reshape((1, gamma.elem_count(), 1, 1))?)?
569
+ }
570
+ None => Ok(ys),
571
+ }
572
+ }
573
+ }
574
+
575
+ #[derive(Debug)]
576
+ struct Yolo12Backbone {
577
+ l0: ConvBlock,
578
+ l1: ConvBlock,
579
+ l2: C3k2,
580
+ l3: ConvBlock,
581
+ l4: C3k2,
582
+ l5: ConvBlock,
583
+ l6: A2C2f,
584
+ l7: ConvBlock,
585
+ l8: A2C2f,
586
+ }
587
+
588
+ impl Yolo12Backbone {
589
+ fn load(vb: VarBuilder, scale: Yolo12Scale) -> Result<Self> {
590
+ let m = scale.multiples();
591
+ let c64 = m.channels(64);
592
+ let c128 = m.channels(128);
593
+ let c256 = m.channels(256);
594
+ let c512 = m.channels(512);
595
+ let c1024 = m.channels(1024);
596
+ let a2_residual = scale.uses_a2_residual();
597
+ let mlp_ratio = if a2_residual { 1.2 } else { 2.0 };
598
+
599
+ Ok(Self {
600
+ l0: ConvBlock::load(vb.pp("model.0"), 3, c64, 3, 2, None, 1, true)?,
601
+ l1: ConvBlock::load(vb.pp("model.1"), c64, c128, 3, 2, None, 1, true)?,
602
+ l2: C3k2::load(
603
+ vb.pp("model.2"),
604
+ c128,
605
+ c256,
606
+ m.repeats(2),
607
+ C3k2Options {
608
+ use_c3k: scale.uses_large_c3k(),
609
+ expansion: 0.25,
610
+ groups: 1,
611
+ shortcut: true,
612
+ },
613
+ )?,
614
+ l3: ConvBlock::load(vb.pp("model.3"), c256, c256, 3, 2, None, 1, true)?,
615
+ l4: C3k2::load(
616
+ vb.pp("model.4"),
617
+ c256,
618
+ c512,
619
+ m.repeats(2),
620
+ C3k2Options {
621
+ use_c3k: scale.uses_large_c3k(),
622
+ expansion: 0.25,
623
+ groups: 1,
624
+ shortcut: true,
625
+ },
626
+ )?,
627
+ l5: ConvBlock::load(vb.pp("model.5"), c512, c512, 3, 2, None, 1, true)?,
628
+ l6: A2C2f::load(
629
+ vb.pp("model.6"),
630
+ c512,
631
+ c512,
632
+ m.repeats(4),
633
+ true,
634
+ 4,
635
+ a2_residual,
636
+ mlp_ratio,
637
+ 0.5,
638
+ 1,
639
+ true,
640
+ )?,
641
+ l7: ConvBlock::load(vb.pp("model.7"), c512, c1024, 3, 2, None, 1, true)?,
642
+ l8: A2C2f::load(
643
+ vb.pp("model.8"),
644
+ c1024,
645
+ c1024,
646
+ m.repeats(4),
647
+ true,
648
+ 1,
649
+ a2_residual,
650
+ mlp_ratio,
651
+ 0.5,
652
+ 1,
653
+ true,
654
+ )?,
655
+ })
656
+ }
657
+
658
+ fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
659
+ let x0 = self.l0.forward(xs)?;
660
+ let x1 = self.l1.forward(&x0)?;
661
+ let x2 = self.l2.forward(&x1)?;
662
+ let x3 = self.l3.forward(&x2)?;
663
+ let x4 = self.l4.forward(&x3)?;
664
+ let x5 = self.l5.forward(&x4)?;
665
+ let x6 = self.l6.forward(&x5)?;
666
+ let x7 = self.l7.forward(&x6)?;
667
+ let x8 = self.l8.forward(&x7)?;
668
+ Ok((x4, x6, x8))
669
+ }
670
+ }
671
+
672
+ #[derive(Debug)]
673
+ struct Yolo12Neck {
674
+ upsample: Upsample,
675
+ l11: A2C2f,
676
+ l14: A2C2f,
677
+ l15: ConvBlock,
678
+ l17: A2C2f,
679
+ l18: ConvBlock,
680
+ l20: C3k2,
681
+ }
682
+
683
+ impl Yolo12Neck {
684
+ fn load(vb: VarBuilder, scale: Yolo12Scale) -> Result<Self> {
685
+ let m = scale.multiples();
686
+ let c256 = m.channels(256);
687
+ let c512 = m.channels(512);
688
+ let c1024 = m.channels(1024);
689
+ let repeats = m.repeats(2);
690
+ Ok(Self {
691
+ upsample: Upsample::new(2),
692
+ l11: A2C2f::load(
693
+ vb.pp("model.11"),
694
+ c1024 + c512,
695
+ c512,
696
+ repeats,
697
+ false,
698
+ 1,
699
+ false,
700
+ 2.0,
701
+ 0.5,
702
+ 1,
703
+ true,
704
+ )?,
705
+ l14: A2C2f::load(
706
+ vb.pp("model.14"),
707
+ c512 + c512,
708
+ c256,
709
+ repeats,
710
+ false,
711
+ 1,
712
+ false,
713
+ 2.0,
714
+ 0.5,
715
+ 1,
716
+ true,
717
+ )?,
718
+ l15: ConvBlock::load(vb.pp("model.15"), c256, c256, 3, 2, None, 1, true)?,
719
+ l17: A2C2f::load(
720
+ vb.pp("model.17"),
721
+ c256 + c512,
722
+ c512,
723
+ repeats,
724
+ false,
725
+ 1,
726
+ false,
727
+ 2.0,
728
+ 0.5,
729
+ 1,
730
+ true,
731
+ )?,
732
+ l18: ConvBlock::load(vb.pp("model.18"), c512, c512, 3, 2, None, 1, true)?,
733
+ l20: C3k2::load(
734
+ vb.pp("model.20"),
735
+ c512 + c1024,
736
+ c1024,
737
+ repeats,
738
+ C3k2Options {
739
+ use_c3k: true,
740
+ expansion: 0.5,
741
+ groups: 1,
742
+ shortcut: true,
743
+ },
744
+ )?,
745
+ })
746
+ }
747
+
748
+ fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
749
+ let x11 = self
750
+ .l11
751
+ .forward(&Tensor::cat(&[&self.upsample.forward(p5)?, p4], 1)?)?;
752
+ let x14 = self
753
+ .l14
754
+ .forward(&Tensor::cat(&[&self.upsample.forward(&x11)?, p3], 1)?)?;
755
+ let x17 = self
756
+ .l17
757
+ .forward(&Tensor::cat(&[&self.l15.forward(&x14)?, &x11], 1)?)?;
758
+ let x20 = self
759
+ .l20
760
+ .forward(&Tensor::cat(&[&self.l18.forward(&x17)?, p5], 1)?)?;
761
+ Ok((x14, x17, x20))
762
+ }
763
+ }
764
+
765
+ #[derive(Debug)]
766
+ struct Dfl {
767
+ conv: Conv2d,
768
+ reg_max: usize,
769
+ }
770
+
771
+ impl Dfl {
772
+ fn load(vb: VarBuilder, reg_max: usize) -> Result<Self> {
773
+ Ok(Self {
774
+ conv: conv2d_no_bias(reg_max, 1, 1, Default::default(), vb.pp("conv"))?,
775
+ reg_max,
776
+ })
777
+ }
778
+ }
779
+
780
+ impl Module for Dfl {
781
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
782
+ let (batch, _, anchors) = xs.dims3()?;
783
+ let xs = xs
784
+ .reshape((batch, 4, self.reg_max, anchors))?
785
+ .transpose(2, 1)?;
786
+ let xs = candle_nn::ops::softmax(&xs, 1)?;
787
+ self.conv.forward(&xs)?.reshape((batch, 4, anchors))
788
+ }
789
+ }
790
+
791
+ #[derive(Debug)]
792
+ struct DetectCv3 {
793
+ dw0: ConvBlock,
794
+ pw0: ConvBlock,
795
+ dw1: ConvBlock,
796
+ pw1: ConvBlock,
797
+ conv: Conv2d,
798
+ }
799
+
800
+ impl DetectCv3 {
801
+ fn load(vb: VarBuilder, in_channels: usize, hidden: usize, num_classes: usize) -> Result<Self> {
802
+ Ok(Self {
803
+ dw0: ConvBlock::load(
804
+ vb.pp("0.0"),
805
+ in_channels,
806
+ in_channels,
807
+ 3,
808
+ 1,
809
+ None,
810
+ in_channels,
811
+ true,
812
+ )?,
813
+ pw0: ConvBlock::load(vb.pp("0.1"), in_channels, hidden, 1, 1, None, 1, true)?,
814
+ dw1: ConvBlock::load(vb.pp("1.0"), hidden, hidden, 3, 1, None, hidden, true)?,
815
+ pw1: ConvBlock::load(vb.pp("1.1"), hidden, hidden, 1, 1, None, 1, true)?,
816
+ conv: conv2d(hidden, num_classes, 1, Default::default(), vb.pp("2"))?,
817
+ })
818
+ }
819
+ }
820
+
821
+ impl Module for DetectCv3 {
822
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
823
+ let xs = self.pw0.forward(&self.dw0.forward(xs)?)?;
824
+ let xs = self.pw1.forward(&self.dw1.forward(&xs)?)?;
825
+ self.conv.forward(&xs)
826
+ }
827
+ }
828
+
829
+ #[derive(Debug)]
830
+ struct DetectionHead {
831
+ dfl: Dfl,
832
+ cv2: [(ConvBlock, ConvBlock, Conv2d); 3],
833
+ cv3: [DetectCv3; 3],
834
+ reg_max: usize,
835
+ no: usize,
836
+ }
837
+
838
+ impl DetectionHead {
839
+ fn load(vb: VarBuilder, num_classes: usize, filters: (usize, usize, usize)) -> Result<Self> {
840
+ let c2 = filters.0.div_ceil(4).max(REG_MAX * 4).max(16);
841
+ let c3 = filters.0.max(num_classes.min(100));
842
+ Ok(Self {
843
+ dfl: Dfl::load(vb.pp("dfl"), REG_MAX)?,
844
+ cv2: [
845
+ Self::load_cv2(vb.pp("cv2.0"), filters.0, c2)?,
846
+ Self::load_cv2(vb.pp("cv2.1"), filters.1, c2)?,
847
+ Self::load_cv2(vb.pp("cv2.2"), filters.2, c2)?,
848
+ ],
849
+ cv3: [
850
+ DetectCv3::load(vb.pp("cv3.0"), filters.0, c3, num_classes)?,
851
+ DetectCv3::load(vb.pp("cv3.1"), filters.1, c3, num_classes)?,
852
+ DetectCv3::load(vb.pp("cv3.2"), filters.2, c3, num_classes)?,
853
+ ],
854
+ reg_max: REG_MAX,
855
+ no: num_classes + REG_MAX * 4,
856
+ })
857
+ }
858
+
859
+ fn load_cv2(
860
+ vb: VarBuilder,
861
+ in_channels: usize,
862
+ hidden: usize,
863
+ ) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
864
+ Ok((
865
+ ConvBlock::load(vb.pp("0"), in_channels, hidden, 3, 1, None, 1, true)?,
866
+ ConvBlock::load(vb.pp("1"), hidden, hidden, 3, 1, None, 1, true)?,
867
+ conv2d(hidden, REG_MAX * 4, 1, Default::default(), vb.pp("2"))?,
868
+ ))
869
+ }
870
+
871
+ fn forward_cv2(block: &(ConvBlock, ConvBlock, Conv2d), xs: &Tensor) -> Result<Tensor> {
872
+ block.2.forward(&block.1.forward(&block.0.forward(xs)?)?)
873
+ }
874
+
875
+ fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
876
+ let xs0 = Tensor::cat(
877
+ &[
878
+ &Self::forward_cv2(&self.cv2[0], xs0)?,
879
+ &self.cv3[0].forward(xs0)?,
880
+ ],
881
+ 1,
882
+ )?;
883
+ let xs1 = Tensor::cat(
884
+ &[
885
+ &Self::forward_cv2(&self.cv2[1], xs1)?,
886
+ &self.cv3[1].forward(xs1)?,
887
+ ],
888
+ 1,
889
+ )?;
890
+ let xs2 = Tensor::cat(
891
+ &[
892
+ &Self::forward_cv2(&self.cv2[2], xs2)?,
893
+ &self.cv3[2].forward(xs2)?,
894
+ ],
895
+ 1,
896
+ )?;
897
+
898
+ let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
899
+ let anchors = anchors.transpose(0, 1)?.unsqueeze(0)?;
900
+ let strides = strides.transpose(0, 1)?;
901
+
902
+ let reshape = |xs: &Tensor| {
903
+ let batch = xs.dim(0)?;
904
+ xs.reshape((batch, self.no, xs.elem_count() / (batch * self.no)))
905
+ };
906
+ let ys0 = reshape(&xs0)?;
907
+ let ys1 = reshape(&xs1)?;
908
+ let ys2 = reshape(&xs2)?;
909
+ let x_cat = Tensor::cat(&[&ys0, &ys1, &ys2], 2)?;
910
+ let box_ = x_cat.i((.., ..self.reg_max * 4, ..))?;
911
+ let cls = x_cat.i((.., self.reg_max * 4.., ..))?;
912
+ let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors)?.broadcast_mul(&strides)?;
913
+ Tensor::cat(&[&dbox, &candle_nn::ops::sigmoid(&cls)?], 1)
914
+ }
915
+ }
916
+
917
+ fn make_anchors(
918
+ xs0: &Tensor,
919
+ xs1: &Tensor,
920
+ xs2: &Tensor,
921
+ strides: (usize, usize, usize),
922
+ grid_cell_offset: f64,
923
+ ) -> Result<(Tensor, Tensor)> {
924
+ let device = xs0.device();
925
+ let dtype = xs0.dtype();
926
+ let mut anchor_points = Vec::with_capacity(3);
927
+ let mut stride_tensors = Vec::with_capacity(3);
928
+ for (xs, stride) in [(xs0, strides.0), (xs1, strides.1), (xs2, strides.2)] {
929
+ let (_, _, h, w) = xs.dims4()?;
930
+ let sx = (Tensor::arange(0, w as u32, device)?.to_dtype(dtype)? + grid_cell_offset)?;
931
+ let sy = (Tensor::arange(0, h as u32, device)?.to_dtype(dtype)? + grid_cell_offset)?;
932
+ let sx = sx
933
+ .reshape((1, sx.elem_count()))?
934
+ .repeat((h, 1))?
935
+ .flatten_all()?;
936
+ let sy = sy
937
+ .reshape((sy.elem_count(), 1))?
938
+ .repeat((1, w))?
939
+ .flatten_all()?;
940
+ anchor_points.push(Tensor::stack(&[&sx, &sy], D::Minus1)?);
941
+ stride_tensors.push((Tensor::ones(h * w, dtype, device)? * stride as f64)?);
942
+ }
943
+ let anchor_points = Tensor::cat(anchor_points.as_slice(), 0)?;
944
+ let stride_tensor = Tensor::cat(stride_tensors.as_slice(), 0)?.unsqueeze(1)?;
945
+ Ok((anchor_points, stride_tensor))
946
+ }
947
+
948
+ fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result<Tensor> {
949
+ let chunks = distance.chunk(2, 1)?;
950
+ let lt = &chunks[0];
951
+ let rb = &chunks[1];
952
+ let x1y1 = anchor_points.sub(lt)?;
953
+ let x2y2 = anchor_points.add(rb)?;
954
+ let c_xy = ((&x1y1 + &x2y2)? * 0.5)?;
955
+ let wh = (&x2y2 - &x1y1)?;
956
+ Tensor::cat(&[&c_xy, &wh], 1)
957
+ }
958
+
959
+ #[derive(Debug)]
960
+ pub struct Yolo12 {
961
+ backbone: Yolo12Backbone,
962
+ neck: Yolo12Neck,
963
+ head: DetectionHead,
964
+ }
965
+
966
+ impl Yolo12 {
967
+ pub fn load(vb: VarBuilder, scale: Yolo12Scale, num_classes: usize) -> Result<Self> {
968
+ let m = scale.multiples();
969
+ let filters = (m.channels(256), m.channels(512), m.channels(1024));
970
+ Ok(Self {
971
+ backbone: Yolo12Backbone::load(vb.clone(), scale)?,
972
+ neck: Yolo12Neck::load(vb.clone(), scale)?,
973
+ head: DetectionHead::load(vb.pp("model.21"), num_classes, filters)?,
974
+ })
975
+ }
976
+
977
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
978
+ let (p3, p4, p5) = self.backbone.forward(xs)?;
979
+ let (h1, h2, h3) = self.neck.forward(&p3, &p4, &p5)?;
980
+ self.head.forward(&h1, &h2, &h3)
981
+ }
982
+ }
koharu-ml/src/lib.rs CHANGED
@@ -1,5 +1,6 @@
1
  mod hf_hub;
2
 
 
3
  pub mod aot_inpainting;
4
  pub mod comic_text_bubble_detector;
5
  pub mod comic_text_detector;
 
1
  mod hf_hub;
2
 
3
+ pub mod anime_text;
4
  pub mod aot_inpainting;
5
  pub mod comic_text_bubble_detector;
6
  pub mod comic_text_detector;
koharu-ml/tests/anime_text.rs ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::path::Path;
2
+
3
+ use koharu_ml::anime_text::AnimeTextDetector;
4
+
5
+ mod support;
6
+
7
+ #[tokio::test]
8
+ #[ignore = "requires model download and is not critical for CI"]
9
+ async fn anime_text_yolo() -> anyhow::Result<()> {
10
+ let runtime = support::cpu_runtime();
11
+ let model = AnimeTextDetector::load(&runtime, false).await?;
12
+
13
+ let image = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/1.jpg"))?;
14
+ let detection = model.inference(&image)?;
15
+
16
+ assert_eq!(detection.image_width, image.width());
17
+ assert_eq!(detection.image_height, image.height());
18
+ assert!(
19
+ !detection.text_blocks.is_empty(),
20
+ "expected anime text YOLO to detect text blocks"
21
+ );
22
+ assert!(
23
+ detection
24
+ .text_blocks
25
+ .iter()
26
+ .all(|block| block.detector.as_deref() == Some("anime-text-yolo"))
27
+ );
28
+
29
+ Ok(())
30
+ }