Mayo commited on
feat: anime-text model
Browse files- README.md +1 -0
- koharu-app/src/config.rs +1 -1
- koharu-app/src/pipeline/engines/anime_text.rs +59 -0
- koharu-app/src/pipeline/engines/mod.rs +1 -0
- koharu-app/src/pipeline/mod.rs +16 -0
- koharu-ml/Cargo.toml +4 -0
- koharu-ml/bin/anime-text.rs +112 -0
- koharu-ml/src/anime_text/mod.rs +440 -0
- koharu-ml/src/anime_text/model.rs +982 -0
- koharu-ml/src/lib.rs +1 -0
- koharu-ml/tests/anime_text.rs +30 -0
README.md
CHANGED
|
@@ -186,6 +186,7 @@ Koharu uses multiple pretrained models, each tuned for a specific part of the pa
|
|
| 186 |
|
| 187 |
These models find text regions, speech bubbles, and page structure.
|
| 188 |
|
|
|
|
| 189 |
- [comic-text-bubble-detector](https://huggingface.co/ogkalu/comic-text-and-bubble-detector) for joint text block and speech bubble detection
|
| 190 |
- [comic-text-detector](https://huggingface.co/mayocream/comic-text-detector) for text segmentation masks
|
| 191 |
- [PP-DocLayoutV3](https://huggingface.co/PaddlePaddle/PP-DocLayoutV3_safetensors) for document layout analysis
|
|
|
|
| 186 |
|
| 187 |
These models find text regions, speech bubbles, and page structure.
|
| 188 |
|
| 189 |
+
- [anime-text-yolo](https://huggingface.co/mayocream/anime-text-yolo) for text block detection
|
| 190 |
- [comic-text-bubble-detector](https://huggingface.co/ogkalu/comic-text-and-bubble-detector) for joint text block and speech bubble detection
|
| 191 |
- [comic-text-detector](https://huggingface.co/mayocream/comic-text-detector) for text segmentation masks
|
| 192 |
- [PP-DocLayoutV3](https://huggingface.co/PaddlePaddle/PP-DocLayoutV3_safetensors) for document layout analysis
|
koharu-app/src/config.rs
CHANGED
|
@@ -83,7 +83,7 @@ pub struct PipelineConfig {
|
|
| 83 |
impl Default for PipelineConfig {
|
| 84 |
fn default() -> Self {
|
| 85 |
Self {
|
| 86 |
-
detector: "
|
| 87 |
font_detector: "yuzumarker-font-detection".to_string(),
|
| 88 |
segmenter: "comic-text-detector-seg".to_string(),
|
| 89 |
bubble_segmenter: "speech-bubble-segmentation".to_string(),
|
|
|
|
| 83 |
impl Default for PipelineConfig {
|
| 84 |
fn default() -> Self {
|
| 85 |
Self {
|
| 86 |
+
detector: "anime-text".to_string(),
|
| 87 |
font_detector: "yuzumarker-font-detection".to_string(),
|
| 88 |
segmenter: "comic-text-detector-seg".to_string(),
|
| 89 |
bubble_segmenter: "speech-bubble-segmentation".to_string(),
|
koharu-app/src/pipeline/engines/anime_text.rs
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! Anime Text YOLO detector. Emits `AddNode` ops for each detected text region.
|
| 2 |
+
|
| 3 |
+
use anyhow::Result;
|
| 4 |
+
use async_trait::async_trait;
|
| 5 |
+
use koharu_core::{Op, TextData};
|
| 6 |
+
use koharu_ml::anime_text::AnimeTextDetector;
|
| 7 |
+
|
| 8 |
+
use crate::pipeline::artifacts::Artifact;
|
| 9 |
+
use crate::pipeline::engine::{Engine, EngineCtx, EngineInfo};
|
| 10 |
+
use crate::pipeline::engines::support::{
|
| 11 |
+
clear_text_nodes_ops, load_source_image, new_text_node, page_node_count,
|
| 12 |
+
sort_manga_reading_order, text_region_to_pair,
|
| 13 |
+
};
|
| 14 |
+
|
| 15 |
+
const DETECTOR_NAME: &str = "anime-text";
|
| 16 |
+
|
| 17 |
+
pub struct Model(AnimeTextDetector);
|
| 18 |
+
|
| 19 |
+
#[async_trait]
|
| 20 |
+
impl Engine for Model {
|
| 21 |
+
async fn run(&self, ctx: EngineCtx<'_>) -> Result<Vec<Op>> {
|
| 22 |
+
let image = load_source_image(ctx.scene, ctx.page, ctx.blobs)?;
|
| 23 |
+
let det = self.0.inference(&image)?;
|
| 24 |
+
|
| 25 |
+
let mut pairs: Vec<([f32; 4], TextData)> = det
|
| 26 |
+
.text_blocks
|
| 27 |
+
.into_iter()
|
| 28 |
+
.map(|r| text_region_to_pair(r, DETECTOR_NAME))
|
| 29 |
+
.collect();
|
| 30 |
+
sort_manga_reading_order(&mut pairs);
|
| 31 |
+
|
| 32 |
+
let mut ops = clear_text_nodes_ops(ctx.scene, ctx.page);
|
| 33 |
+
let removed = ops.len();
|
| 34 |
+
let insertion_start = page_node_count(ctx.scene, ctx.page).saturating_sub(removed);
|
| 35 |
+
ops.reserve(pairs.len());
|
| 36 |
+
for (at, (bbox, text)) in (insertion_start..).zip(pairs) {
|
| 37 |
+
let node = new_text_node(bbox, text);
|
| 38 |
+
ops.push(Op::AddNode {
|
| 39 |
+
page: ctx.page,
|
| 40 |
+
node,
|
| 41 |
+
at,
|
| 42 |
+
});
|
| 43 |
+
}
|
| 44 |
+
Ok(ops)
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
inventory::submit! {
|
| 49 |
+
EngineInfo {
|
| 50 |
+
id: "anime-text",
|
| 51 |
+
name: "Anime Text YOLO (N)",
|
| 52 |
+
needs: &[],
|
| 53 |
+
produces: &[Artifact::TextBoxes],
|
| 54 |
+
load: |runtime, cpu| Box::pin(async move {
|
| 55 |
+
let m = AnimeTextDetector::load(runtime, cpu).await?;
|
| 56 |
+
Ok(Box::new(Model(m)) as Box<dyn Engine>)
|
| 57 |
+
}),
|
| 58 |
+
}
|
| 59 |
+
}
|
koharu-app/src/pipeline/engines/mod.rs
CHANGED
|
@@ -4,6 +4,7 @@
|
|
| 4 |
//! `inventory::submit! { EngineInfo { … } }`. The registry picks them up
|
| 5 |
//! automatically at link time.
|
| 6 |
|
|
|
|
| 7 |
pub mod aot;
|
| 8 |
pub mod bubble_segmentation;
|
| 9 |
pub mod comic_text_bubble;
|
|
|
|
| 4 |
//! `inventory::submit! { EngineInfo { … } }`. The registry picks them up
|
| 5 |
//! automatically at link time.
|
| 6 |
|
| 7 |
+
pub mod anime_text;
|
| 8 |
pub mod aot;
|
| 9 |
pub mod bubble_segmentation;
|
| 10 |
pub mod comic_text_bubble;
|
koharu-app/src/pipeline/mod.rs
CHANGED
|
@@ -350,3 +350,19 @@ pub fn catalog() -> EngineCatalog {
|
|
| 350 |
.collect(),
|
| 351 |
}
|
| 352 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
.collect(),
|
| 351 |
}
|
| 352 |
}
|
| 353 |
+
|
| 354 |
+
#[cfg(test)]
|
| 355 |
+
mod tests {
|
| 356 |
+
use super::*;
|
| 357 |
+
|
| 358 |
+
#[test]
|
| 359 |
+
fn catalog_includes_anime_text_detector() {
|
| 360 |
+
let catalog = catalog();
|
| 361 |
+
|
| 362 |
+
assert!(catalog.detectors.iter().any(|engine| {
|
| 363 |
+
engine.id == "anime-text"
|
| 364 |
+
&& engine.name == "Anime Text YOLO (N)"
|
| 365 |
+
&& engine.produces.iter().map(String::as_str).eq(["TextBoxes"])
|
| 366 |
+
}));
|
| 367 |
+
}
|
| 368 |
+
}
|
koharu-ml/Cargo.toml
CHANGED
|
@@ -102,6 +102,10 @@ path = "bin/manga-text-segmentation-2025.rs"
|
|
| 102 |
name = "speech-bubble-segmentation"
|
| 103 |
path = "bin/speech-bubble-segmentation.rs"
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
[[bin]]
|
| 106 |
name = "aot-inpainting"
|
| 107 |
path = "bin/aot-inpainting.rs"
|
|
|
|
| 102 |
name = "speech-bubble-segmentation"
|
| 103 |
path = "bin/speech-bubble-segmentation.rs"
|
| 104 |
|
| 105 |
+
[[bin]]
|
| 106 |
+
name = "anime-text"
|
| 107 |
+
path = "bin/anime-text.rs"
|
| 108 |
+
|
| 109 |
[[bin]]
|
| 110 |
name = "aot-inpainting"
|
| 111 |
path = "bin/aot-inpainting.rs"
|
koharu-ml/bin/anime-text.rs
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use anyhow::{Result, anyhow, ensure};
|
| 2 |
+
use clap::{Parser, ValueEnum};
|
| 3 |
+
use imageproc::{drawing::draw_hollow_rect_mut, rect::Rect};
|
| 4 |
+
use koharu_ml::anime_text::{AnimeTextDetector, AnimeTextYoloVariant};
|
| 5 |
+
use koharu_runtime::{ComputePolicy, RuntimeManager, default_app_data_root};
|
| 6 |
+
use tokio::runtime::Builder;
|
| 7 |
+
|
| 8 |
+
#[path = "common.rs"]
|
| 9 |
+
mod common;
|
| 10 |
+
|
| 11 |
+
#[derive(Clone, Copy, Debug, ValueEnum)]
|
| 12 |
+
enum Variant {
|
| 13 |
+
N,
|
| 14 |
+
S,
|
| 15 |
+
M,
|
| 16 |
+
L,
|
| 17 |
+
X,
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
impl From<Variant> for AnimeTextYoloVariant {
|
| 21 |
+
fn from(value: Variant) -> Self {
|
| 22 |
+
match value {
|
| 23 |
+
Variant::N => Self::N,
|
| 24 |
+
Variant::S => Self::S,
|
| 25 |
+
Variant::M => Self::M,
|
| 26 |
+
Variant::L => Self::L,
|
| 27 |
+
Variant::X => Self::X,
|
| 28 |
+
}
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
#[derive(Parser)]
|
| 33 |
+
struct Cli {
|
| 34 |
+
#[arg(short, long, value_name = "FILE")]
|
| 35 |
+
input: String,
|
| 36 |
+
|
| 37 |
+
#[arg(short, long, value_name = "FILE")]
|
| 38 |
+
output: String,
|
| 39 |
+
|
| 40 |
+
#[arg(long, value_name = "FILE")]
|
| 41 |
+
json_output: Option<String>,
|
| 42 |
+
|
| 43 |
+
#[arg(long, value_enum, default_value_t = Variant::N)]
|
| 44 |
+
variant: Variant,
|
| 45 |
+
|
| 46 |
+
#[arg(long, default_value_t = 0.25)]
|
| 47 |
+
confidence_threshold: f32,
|
| 48 |
+
|
| 49 |
+
#[arg(long, default_value_t = 0.45)]
|
| 50 |
+
nms_threshold: f32,
|
| 51 |
+
|
| 52 |
+
#[arg(long, default_value_t = false)]
|
| 53 |
+
cpu: bool,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
fn main() -> Result<()> {
|
| 57 |
+
common::init_tracing();
|
| 58 |
+
|
| 59 |
+
std::thread::Builder::new()
|
| 60 |
+
.name("anime-text-yolo".to_string())
|
| 61 |
+
.stack_size(64 * 1024 * 1024)
|
| 62 |
+
.spawn(|| {
|
| 63 |
+
let runtime = Builder::new_current_thread().enable_all().build()?;
|
| 64 |
+
runtime.block_on(async_main())
|
| 65 |
+
})?
|
| 66 |
+
.join()
|
| 67 |
+
.map_err(|_| anyhow!("anime-text-yolo thread panicked"))?
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
async fn async_main() -> Result<()> {
|
| 71 |
+
let cli = Cli::parse();
|
| 72 |
+
let variant = AnimeTextYoloVariant::from(cli.variant);
|
| 73 |
+
|
| 74 |
+
let runtime = RuntimeManager::new(
|
| 75 |
+
default_app_data_root(),
|
| 76 |
+
if cli.cpu {
|
| 77 |
+
ComputePolicy::CpuOnly
|
| 78 |
+
} else {
|
| 79 |
+
ComputePolicy::PreferGpu
|
| 80 |
+
},
|
| 81 |
+
)?;
|
| 82 |
+
runtime.prepare().await?;
|
| 83 |
+
|
| 84 |
+
let model = AnimeTextDetector::load_variant(&runtime, variant, cli.cpu).await?;
|
| 85 |
+
let bytes = std::fs::read(&cli.input)?;
|
| 86 |
+
let format = image::guess_format(&bytes)?;
|
| 87 |
+
let image = image::load_from_memory_with_format(&bytes, format)?;
|
| 88 |
+
let detection =
|
| 89 |
+
model.inference_with_thresholds(&image, cli.confidence_threshold, cli.nms_threshold)?;
|
| 90 |
+
|
| 91 |
+
ensure!(
|
| 92 |
+
!detection.regions.is_empty(),
|
| 93 |
+
"No anime text blocks detected in the image."
|
| 94 |
+
);
|
| 95 |
+
|
| 96 |
+
let mut image = image.to_rgba8();
|
| 97 |
+
for region in &detection.regions {
|
| 98 |
+
let width = (region.bbox[2] - region.bbox[0]).max(1.0) as u32;
|
| 99 |
+
let height = (region.bbox[3] - region.bbox[1]).max(1.0) as u32;
|
| 100 |
+
draw_hollow_rect_mut(
|
| 101 |
+
&mut image,
|
| 102 |
+
Rect::at(region.bbox[0] as i32, region.bbox[1] as i32).of_size(width, height),
|
| 103 |
+
image::Rgba([255, 0, 0, 255]),
|
| 104 |
+
);
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
image::DynamicImage::ImageRgba8(image).save(&cli.output)?;
|
| 108 |
+
if let Some(path) = &cli.json_output {
|
| 109 |
+
std::fs::write(path, serde_json::to_vec_pretty(&detection)?)?;
|
| 110 |
+
}
|
| 111 |
+
Ok(())
|
| 112 |
+
}
|
koharu-ml/src/anime_text/mod.rs
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mod model;
|
| 2 |
+
|
| 3 |
+
use std::{path::Path, path::PathBuf, time::Instant};
|
| 4 |
+
|
| 5 |
+
use anyhow::{Context, Result, bail};
|
| 6 |
+
use candle_core::{DType, Device, IndexOp, Tensor};
|
| 7 |
+
use candle_transformers::object_detection::{Bbox, non_maximum_suppression};
|
| 8 |
+
use image::{
|
| 9 |
+
DynamicImage, Rgb, RgbImage,
|
| 10 |
+
imageops::{self, FilterType},
|
| 11 |
+
};
|
| 12 |
+
use koharu_runtime::RuntimeManager;
|
| 13 |
+
use serde::{Deserialize, Serialize};
|
| 14 |
+
use tracing::instrument;
|
| 15 |
+
|
| 16 |
+
use crate::{device, loading, types::TextRegion};
|
| 17 |
+
|
| 18 |
+
use self::model::{Yolo12, Yolo12Scale};
|
| 19 |
+
|
| 20 |
+
pub const HF_REPO: &str = "mayocream/anime-text-yolo";
|
| 21 |
+
const INPUT_SIZE: u32 = 640;
|
| 22 |
+
const NUM_CLASSES: usize = 1;
|
| 23 |
+
const DEFAULT_VARIANT: AnimeTextYoloVariant = AnimeTextYoloVariant::N;
|
| 24 |
+
const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.25;
|
| 25 |
+
const DEFAULT_NMS_THRESHOLD: f32 = 0.45;
|
| 26 |
+
const LETTERBOX_COLOR: u8 = 114;
|
| 27 |
+
const DETECTOR_NAME: &str = "anime-text-yolo";
|
| 28 |
+
const CLASS_NAMES: [&str; NUM_CLASSES] = ["text_block"];
|
| 29 |
+
|
| 30 |
+
koharu_runtime::declare_hf_model_package!(
|
| 31 |
+
id: "model:anime-text-yolo:yolo12n",
|
| 32 |
+
repo: HF_REPO,
|
| 33 |
+
file: "yolo12n_animetext.safetensors",
|
| 34 |
+
bootstrap: false,
|
| 35 |
+
order: 118,
|
| 36 |
+
);
|
| 37 |
+
koharu_runtime::declare_hf_model_package!(
|
| 38 |
+
id: "model:anime-text-yolo:yolo12s",
|
| 39 |
+
repo: HF_REPO,
|
| 40 |
+
file: "yolo12s_animetext.safetensors",
|
| 41 |
+
bootstrap: false,
|
| 42 |
+
order: 119,
|
| 43 |
+
);
|
| 44 |
+
koharu_runtime::declare_hf_model_package!(
|
| 45 |
+
id: "model:anime-text-yolo:yolo12m",
|
| 46 |
+
repo: HF_REPO,
|
| 47 |
+
file: "yolo12m_animetext.safetensors",
|
| 48 |
+
bootstrap: false,
|
| 49 |
+
order: 120,
|
| 50 |
+
);
|
| 51 |
+
koharu_runtime::declare_hf_model_package!(
|
| 52 |
+
id: "model:anime-text-yolo:yolo12l",
|
| 53 |
+
repo: HF_REPO,
|
| 54 |
+
file: "yolo12l_animetext.safetensors",
|
| 55 |
+
bootstrap: false,
|
| 56 |
+
order: 121,
|
| 57 |
+
);
|
| 58 |
+
koharu_runtime::declare_hf_model_package!(
|
| 59 |
+
id: "model:anime-text-yolo:yolo12x",
|
| 60 |
+
repo: HF_REPO,
|
| 61 |
+
file: "yolo12x_animetext.safetensors",
|
| 62 |
+
bootstrap: false,
|
| 63 |
+
order: 122,
|
| 64 |
+
);
|
| 65 |
+
|
| 66 |
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
| 67 |
+
#[serde(rename_all = "lowercase")]
|
| 68 |
+
pub enum AnimeTextYoloVariant {
|
| 69 |
+
N,
|
| 70 |
+
S,
|
| 71 |
+
M,
|
| 72 |
+
L,
|
| 73 |
+
X,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
impl AnimeTextYoloVariant {
|
| 77 |
+
pub fn filename(self) -> &'static str {
|
| 78 |
+
match self {
|
| 79 |
+
Self::N => "yolo12n_animetext.safetensors",
|
| 80 |
+
Self::S => "yolo12s_animetext.safetensors",
|
| 81 |
+
Self::M => "yolo12m_animetext.safetensors",
|
| 82 |
+
Self::L => "yolo12l_animetext.safetensors",
|
| 83 |
+
Self::X => "yolo12x_animetext.safetensors",
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
pub fn as_str(self) -> &'static str {
|
| 88 |
+
match self {
|
| 89 |
+
Self::N => "n",
|
| 90 |
+
Self::S => "s",
|
| 91 |
+
Self::M => "m",
|
| 92 |
+
Self::L => "l",
|
| 93 |
+
Self::X => "x",
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
fn scale(self) -> Yolo12Scale {
|
| 98 |
+
match self {
|
| 99 |
+
Self::N => Yolo12Scale::N,
|
| 100 |
+
Self::S => Yolo12Scale::S,
|
| 101 |
+
Self::M => Yolo12Scale::M,
|
| 102 |
+
Self::L => Yolo12Scale::L,
|
| 103 |
+
Self::X => Yolo12Scale::X,
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
impl std::fmt::Display for AnimeTextYoloVariant {
|
| 109 |
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
| 110 |
+
f.write_str(self.as_str())
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
#[derive(Debug)]
|
| 115 |
+
pub struct AnimeTextDetector {
|
| 116 |
+
model: Yolo12,
|
| 117 |
+
variant: AnimeTextYoloVariant,
|
| 118 |
+
device: Device,
|
| 119 |
+
dtype: DType,
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
#[derive(Debug, Clone)]
|
| 123 |
+
struct PreparedInput {
|
| 124 |
+
pixel_values: Tensor,
|
| 125 |
+
original_width: u32,
|
| 126 |
+
original_height: u32,
|
| 127 |
+
pad_x: u32,
|
| 128 |
+
pad_y: u32,
|
| 129 |
+
scale: f32,
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
#[derive(Debug, Clone, Serialize)]
|
| 133 |
+
#[serde(rename_all = "camelCase")]
|
| 134 |
+
pub struct AnimeTextDetection {
|
| 135 |
+
pub image_width: u32,
|
| 136 |
+
pub image_height: u32,
|
| 137 |
+
pub variant: AnimeTextYoloVariant,
|
| 138 |
+
pub regions: Vec<AnimeTextRegion>,
|
| 139 |
+
pub text_blocks: Vec<TextRegion>,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
#[derive(Debug, Clone, Serialize)]
|
| 143 |
+
#[serde(rename_all = "camelCase")]
|
| 144 |
+
pub struct AnimeTextRegion {
|
| 145 |
+
pub label_id: usize,
|
| 146 |
+
pub label: String,
|
| 147 |
+
pub score: f32,
|
| 148 |
+
pub bbox: [f32; 4],
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
impl AnimeTextDetector {
|
| 152 |
+
pub async fn load(runtime: &RuntimeManager, cpu: bool) -> Result<Self> {
|
| 153 |
+
Self::load_variant(runtime, DEFAULT_VARIANT, cpu).await
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
pub async fn load_variant(
|
| 157 |
+
runtime: &RuntimeManager,
|
| 158 |
+
variant: AnimeTextYoloVariant,
|
| 159 |
+
cpu: bool,
|
| 160 |
+
) -> Result<Self> {
|
| 161 |
+
let weights_path = resolve_model_path(runtime, variant).await?;
|
| 162 |
+
Self::load_from_path(weights_path, variant, cpu)
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
pub fn load_from_path(
|
| 166 |
+
weights_path: impl AsRef<Path>,
|
| 167 |
+
variant: AnimeTextYoloVariant,
|
| 168 |
+
cpu: bool,
|
| 169 |
+
) -> Result<Self> {
|
| 170 |
+
let device = device(cpu)?;
|
| 171 |
+
let dtype = loading::model_dtype(&device);
|
| 172 |
+
let model = loading::load_mmaped_safetensors_path_with_dtype(
|
| 173 |
+
weights_path.as_ref(),
|
| 174 |
+
&device,
|
| 175 |
+
dtype,
|
| 176 |
+
|vb| Yolo12::load(vb, variant.scale(), NUM_CLASSES),
|
| 177 |
+
)
|
| 178 |
+
.with_context(|| {
|
| 179 |
+
format!(
|
| 180 |
+
"failed to load anime text YOLO {} weights from {}",
|
| 181 |
+
variant,
|
| 182 |
+
weights_path.as_ref().display()
|
| 183 |
+
)
|
| 184 |
+
})?;
|
| 185 |
+
|
| 186 |
+
Ok(Self {
|
| 187 |
+
model,
|
| 188 |
+
variant,
|
| 189 |
+
device,
|
| 190 |
+
dtype,
|
| 191 |
+
})
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
pub fn variant(&self) -> AnimeTextYoloVariant {
|
| 195 |
+
self.variant
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
#[instrument(level = "debug", skip_all)]
|
| 199 |
+
pub fn inference(&self, image: &DynamicImage) -> Result<AnimeTextDetection> {
|
| 200 |
+
self.inference_with_thresholds(image, DEFAULT_CONFIDENCE_THRESHOLD, DEFAULT_NMS_THRESHOLD)
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
#[instrument(level = "debug", skip_all)]
|
| 204 |
+
pub fn inference_with_thresholds(
|
| 205 |
+
&self,
|
| 206 |
+
image: &DynamicImage,
|
| 207 |
+
confidence_threshold: f32,
|
| 208 |
+
nms_threshold: f32,
|
| 209 |
+
) -> Result<AnimeTextDetection> {
|
| 210 |
+
let started = Instant::now();
|
| 211 |
+
let prepared = self.preprocess(image)?;
|
| 212 |
+
let outputs = self.model.forward(&prepared.pixel_values)?;
|
| 213 |
+
let regions = postprocess(&outputs, &prepared, confidence_threshold, nms_threshold)?;
|
| 214 |
+
let text_blocks = regions_to_text_blocks(®ions);
|
| 215 |
+
|
| 216 |
+
tracing::info!(
|
| 217 |
+
width = image.width(),
|
| 218 |
+
height = image.height(),
|
| 219 |
+
variant = %self.variant,
|
| 220 |
+
detections = regions.len(),
|
| 221 |
+
total_ms = started.elapsed().as_millis(),
|
| 222 |
+
"anime text YOLO timings"
|
| 223 |
+
);
|
| 224 |
+
|
| 225 |
+
Ok(AnimeTextDetection {
|
| 226 |
+
image_width: prepared.original_width,
|
| 227 |
+
image_height: prepared.original_height,
|
| 228 |
+
variant: self.variant,
|
| 229 |
+
regions,
|
| 230 |
+
text_blocks,
|
| 231 |
+
})
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
fn preprocess(&self, image: &DynamicImage) -> Result<PreparedInput> {
|
| 235 |
+
let rgb = image.to_rgb8();
|
| 236 |
+
let (original_width, original_height) = rgb.dimensions();
|
| 237 |
+
let scale = f32::min(
|
| 238 |
+
INPUT_SIZE as f32 / original_width.max(1) as f32,
|
| 239 |
+
INPUT_SIZE as f32 / original_height.max(1) as f32,
|
| 240 |
+
);
|
| 241 |
+
let resized_width = ((original_width as f32 * scale).round() as u32).clamp(1, INPUT_SIZE);
|
| 242 |
+
let resized_height = ((original_height as f32 * scale).round() as u32).clamp(1, INPUT_SIZE);
|
| 243 |
+
let pad_x = (INPUT_SIZE - resized_width) / 2;
|
| 244 |
+
let pad_y = (INPUT_SIZE - resized_height) / 2;
|
| 245 |
+
|
| 246 |
+
let resized = if resized_width == original_width && resized_height == original_height {
|
| 247 |
+
rgb
|
| 248 |
+
} else {
|
| 249 |
+
imageops::resize(&rgb, resized_width, resized_height, FilterType::Triangle)
|
| 250 |
+
};
|
| 251 |
+
|
| 252 |
+
let mut letterboxed =
|
| 253 |
+
RgbImage::from_pixel(INPUT_SIZE, INPUT_SIZE, Rgb([LETTERBOX_COLOR; 3]));
|
| 254 |
+
imageops::overlay(
|
| 255 |
+
&mut letterboxed,
|
| 256 |
+
&resized,
|
| 257 |
+
i64::from(pad_x),
|
| 258 |
+
i64::from(pad_y),
|
| 259 |
+
);
|
| 260 |
+
|
| 261 |
+
let pixel_values = Tensor::from_vec(
|
| 262 |
+
letterboxed.into_raw(),
|
| 263 |
+
(1, INPUT_SIZE as usize, INPUT_SIZE as usize, 3),
|
| 264 |
+
&self.device,
|
| 265 |
+
)?
|
| 266 |
+
.permute((0, 3, 1, 2))?
|
| 267 |
+
.to_dtype(self.dtype)?;
|
| 268 |
+
let pixel_values = (pixel_values * (1.0 / 255.0))?;
|
| 269 |
+
|
| 270 |
+
Ok(PreparedInput {
|
| 271 |
+
pixel_values,
|
| 272 |
+
original_width,
|
| 273 |
+
original_height,
|
| 274 |
+
pad_x,
|
| 275 |
+
pad_y,
|
| 276 |
+
scale,
|
| 277 |
+
})
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
pub async fn prefetch(runtime: &RuntimeManager) -> Result<()> {
|
| 282 |
+
prefetch_variant(runtime, DEFAULT_VARIANT).await
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
pub async fn prefetch_variant(
|
| 286 |
+
runtime: &RuntimeManager,
|
| 287 |
+
variant: AnimeTextYoloVariant,
|
| 288 |
+
) -> Result<()> {
|
| 289 |
+
let _ = resolve_model_path(runtime, variant).await?;
|
| 290 |
+
Ok(())
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
async fn resolve_model_path(
|
| 294 |
+
runtime: &RuntimeManager,
|
| 295 |
+
variant: AnimeTextYoloVariant,
|
| 296 |
+
) -> Result<PathBuf> {
|
| 297 |
+
runtime
|
| 298 |
+
.downloads()
|
| 299 |
+
.huggingface_model(HF_REPO, variant.filename())
|
| 300 |
+
.await
|
| 301 |
+
.with_context(|| format!("failed to download {} from {}", variant.filename(), HF_REPO))
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
fn postprocess(
|
| 305 |
+
outputs: &Tensor,
|
| 306 |
+
prepared: &PreparedInput,
|
| 307 |
+
confidence_threshold: f32,
|
| 308 |
+
nms_threshold: f32,
|
| 309 |
+
) -> Result<Vec<AnimeTextRegion>> {
|
| 310 |
+
let pred = outputs
|
| 311 |
+
.to_dtype(DType::F32)?
|
| 312 |
+
.to_device(&Device::Cpu)?
|
| 313 |
+
.i(0)?;
|
| 314 |
+
let (channels, anchors) = pred.dims2()?;
|
| 315 |
+
let expected_channels = 4 + NUM_CLASSES;
|
| 316 |
+
if channels != expected_channels {
|
| 317 |
+
bail!(
|
| 318 |
+
"unexpected anime text YOLO prediction channels {channels}, expected {expected_channels}"
|
| 319 |
+
);
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
let mut grouped: Vec<Vec<Bbox<usize>>> = (0..NUM_CLASSES).map(|_| Vec::new()).collect();
|
| 323 |
+
for anchor_idx in 0..anchors {
|
| 324 |
+
let values = pred.i((.., anchor_idx))?.to_vec1::<f32>()?;
|
| 325 |
+
let class_scores = &values[4..4 + NUM_CLASSES];
|
| 326 |
+
let Some((label_id, &score)) = class_scores
|
| 327 |
+
.iter()
|
| 328 |
+
.enumerate()
|
| 329 |
+
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
| 330 |
+
else {
|
| 331 |
+
continue;
|
| 332 |
+
};
|
| 333 |
+
if score < confidence_threshold {
|
| 334 |
+
continue;
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
let bbox = map_bbox_to_original(
|
| 338 |
+
[
|
| 339 |
+
values[0] - values[2] * 0.5,
|
| 340 |
+
values[1] - values[3] * 0.5,
|
| 341 |
+
values[0] + values[2] * 0.5,
|
| 342 |
+
values[1] + values[3] * 0.5,
|
| 343 |
+
],
|
| 344 |
+
prepared,
|
| 345 |
+
);
|
| 346 |
+
if bbox[2] <= bbox[0] || bbox[3] <= bbox[1] {
|
| 347 |
+
continue;
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
grouped[label_id].push(Bbox {
|
| 351 |
+
xmin: bbox[0],
|
| 352 |
+
ymin: bbox[1],
|
| 353 |
+
xmax: bbox[2],
|
| 354 |
+
ymax: bbox[3],
|
| 355 |
+
confidence: score,
|
| 356 |
+
data: label_id,
|
| 357 |
+
});
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
non_maximum_suppression(&mut grouped, nms_threshold);
|
| 361 |
+
|
| 362 |
+
let mut regions = Vec::new();
|
| 363 |
+
for (label_id, bboxes) in grouped.into_iter().enumerate() {
|
| 364 |
+
let label = CLASS_NAMES
|
| 365 |
+
.get(label_id)
|
| 366 |
+
.copied()
|
| 367 |
+
.unwrap_or("text_block")
|
| 368 |
+
.to_string();
|
| 369 |
+
for bbox in bboxes {
|
| 370 |
+
regions.push(AnimeTextRegion {
|
| 371 |
+
label_id,
|
| 372 |
+
label: label.clone(),
|
| 373 |
+
score: bbox.confidence,
|
| 374 |
+
bbox: [bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax],
|
| 375 |
+
});
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
regions.sort_by(|a, b| b.score.total_cmp(&a.score));
|
| 379 |
+
Ok(regions)
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
fn map_bbox_to_original(bbox: [f32; 4], prepared: &PreparedInput) -> [f32; 4] {
|
| 383 |
+
let width = prepared.original_width as f32;
|
| 384 |
+
let height = prepared.original_height as f32;
|
| 385 |
+
let pad_x = prepared.pad_x as f32;
|
| 386 |
+
let pad_y = prepared.pad_y as f32;
|
| 387 |
+
[
|
| 388 |
+
((bbox[0] - pad_x) / prepared.scale).clamp(0.0, width),
|
| 389 |
+
((bbox[1] - pad_y) / prepared.scale).clamp(0.0, height),
|
| 390 |
+
((bbox[2] - pad_x) / prepared.scale).clamp(0.0, width),
|
| 391 |
+
((bbox[3] - pad_y) / prepared.scale).clamp(0.0, height),
|
| 392 |
+
]
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
fn regions_to_text_blocks(regions: &[AnimeTextRegion]) -> Vec<TextRegion> {
|
| 396 |
+
regions
|
| 397 |
+
.iter()
|
| 398 |
+
.filter_map(|region| {
|
| 399 |
+
let width = (region.bbox[2] - region.bbox[0]).max(0.0);
|
| 400 |
+
let height = (region.bbox[3] - region.bbox[1]).max(0.0);
|
| 401 |
+
if width <= 1.0 || height <= 1.0 {
|
| 402 |
+
return None;
|
| 403 |
+
}
|
| 404 |
+
Some(TextRegion {
|
| 405 |
+
x: region.bbox[0],
|
| 406 |
+
y: region.bbox[1],
|
| 407 |
+
width,
|
| 408 |
+
height,
|
| 409 |
+
confidence: region.score,
|
| 410 |
+
detector: Some(DETECTOR_NAME.to_string()),
|
| 411 |
+
..Default::default()
|
| 412 |
+
})
|
| 413 |
+
})
|
| 414 |
+
.collect()
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
#[cfg(test)]
|
| 418 |
+
mod tests {
|
| 419 |
+
use super::{PreparedInput, map_bbox_to_original};
|
| 420 |
+
use candle_core::{DType, Device, Tensor};
|
| 421 |
+
|
| 422 |
+
#[test]
|
| 423 |
+
fn map_bbox_to_original_removes_letterbox_padding() {
|
| 424 |
+
let prepared = PreparedInput {
|
| 425 |
+
pixel_values: Tensor::zeros((1, 3, 640, 640), DType::F32, &Device::Cpu)
|
| 426 |
+
.expect("tensor"),
|
| 427 |
+
original_width: 1000,
|
| 428 |
+
original_height: 500,
|
| 429 |
+
pad_x: 0,
|
| 430 |
+
pad_y: 160,
|
| 431 |
+
scale: 0.64,
|
| 432 |
+
};
|
| 433 |
+
|
| 434 |
+
let bbox = map_bbox_to_original([100.0, 200.0, 540.0, 440.0], &prepared);
|
| 435 |
+
assert!((bbox[0] - 156.25).abs() < 1e-3);
|
| 436 |
+
assert!((bbox[1] - 62.5).abs() < 1e-3);
|
| 437 |
+
assert!((bbox[2] - 843.75).abs() < 1e-3);
|
| 438 |
+
assert!((bbox[3] - 437.5).abs() < 1e-3);
|
| 439 |
+
}
|
| 440 |
+
}
|
koharu-ml/src/anime_text/model.rs
ADDED
|
@@ -0,0 +1,982 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use candle_core::{D, IndexOp, Result, Tensor};
|
| 2 |
+
use candle_nn::{BatchNorm, Conv2d, Conv2dConfig, Module, ModuleT, VarBuilder, batch_norm};
|
| 3 |
+
|
| 4 |
+
use crate::ops::{conv2d, conv2d_no_bias};
|
| 5 |
+
|
| 6 |
+
const BN_EPS: f64 = 1e-3;
|
| 7 |
+
const REG_MAX: usize = 16;
|
| 8 |
+
|
| 9 |
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
| 10 |
+
pub enum Yolo12Scale {
|
| 11 |
+
N,
|
| 12 |
+
S,
|
| 13 |
+
M,
|
| 14 |
+
L,
|
| 15 |
+
X,
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
#[derive(Debug, Clone, Copy)]
|
| 19 |
+
struct Multiples {
|
| 20 |
+
depth: f64,
|
| 21 |
+
width: f64,
|
| 22 |
+
max_channels: usize,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
impl Yolo12Scale {
|
| 26 |
+
fn multiples(self) -> Multiples {
|
| 27 |
+
match self {
|
| 28 |
+
Self::N => Multiples {
|
| 29 |
+
depth: 0.50,
|
| 30 |
+
width: 0.25,
|
| 31 |
+
max_channels: 1024,
|
| 32 |
+
},
|
| 33 |
+
Self::S => Multiples {
|
| 34 |
+
depth: 0.50,
|
| 35 |
+
width: 0.50,
|
| 36 |
+
max_channels: 1024,
|
| 37 |
+
},
|
| 38 |
+
Self::M => Multiples {
|
| 39 |
+
depth: 0.50,
|
| 40 |
+
width: 1.00,
|
| 41 |
+
max_channels: 512,
|
| 42 |
+
},
|
| 43 |
+
Self::L => Multiples {
|
| 44 |
+
depth: 1.00,
|
| 45 |
+
width: 1.00,
|
| 46 |
+
max_channels: 512,
|
| 47 |
+
},
|
| 48 |
+
Self::X => Multiples {
|
| 49 |
+
depth: 1.00,
|
| 50 |
+
width: 1.50,
|
| 51 |
+
max_channels: 512,
|
| 52 |
+
},
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
fn uses_large_c3k(self) -> bool {
|
| 57 |
+
matches!(self, Self::M | Self::L | Self::X)
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
fn uses_a2_residual(self) -> bool {
|
| 61 |
+
matches!(self, Self::L | Self::X)
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
impl Multiples {
|
| 66 |
+
fn channels(&self, base: usize) -> usize {
|
| 67 |
+
make_divisible((base.min(self.max_channels) as f64) * self.width, 8)
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
fn repeats(&self, base: usize) -> usize {
|
| 71 |
+
if base > 1 {
|
| 72 |
+
((base as f64 * self.depth).round() as usize).max(1)
|
| 73 |
+
} else {
|
| 74 |
+
base
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
fn make_divisible(value: f64, divisor: usize) -> usize {
|
| 80 |
+
((value / divisor as f64).ceil() as usize) * divisor
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
#[derive(Debug)]
|
| 84 |
+
struct Upsample {
|
| 85 |
+
scale_factor: usize,
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
impl Upsample {
|
| 89 |
+
fn new(scale_factor: usize) -> Self {
|
| 90 |
+
Self { scale_factor }
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
impl Module for Upsample {
|
| 95 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 96 |
+
let (_, _, h, w) = xs.dims4()?;
|
| 97 |
+
xs.upsample_nearest2d(self.scale_factor * h, self.scale_factor * w)
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
#[derive(Debug)]
|
| 102 |
+
struct ConvBlock {
|
| 103 |
+
conv: Conv2d,
|
| 104 |
+
bn: BatchNorm,
|
| 105 |
+
activation: bool,
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
impl ConvBlock {
|
| 109 |
+
#[allow(clippy::too_many_arguments)]
|
| 110 |
+
fn load(
|
| 111 |
+
vb: VarBuilder,
|
| 112 |
+
in_channels: usize,
|
| 113 |
+
out_channels: usize,
|
| 114 |
+
kernel_size: usize,
|
| 115 |
+
stride: usize,
|
| 116 |
+
padding: Option<usize>,
|
| 117 |
+
groups: usize,
|
| 118 |
+
activation: bool,
|
| 119 |
+
) -> Result<Self> {
|
| 120 |
+
let cfg = Conv2dConfig {
|
| 121 |
+
padding: padding.unwrap_or(kernel_size / 2),
|
| 122 |
+
stride,
|
| 123 |
+
groups,
|
| 124 |
+
dilation: 1,
|
| 125 |
+
cudnn_fwd_algo: None,
|
| 126 |
+
};
|
| 127 |
+
Ok(Self {
|
| 128 |
+
conv: conv2d_no_bias(in_channels, out_channels, kernel_size, cfg, vb.pp("conv"))?,
|
| 129 |
+
bn: batch_norm(out_channels, BN_EPS, vb.pp("bn"))?,
|
| 130 |
+
activation,
|
| 131 |
+
})
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
impl Module for ConvBlock {
|
| 136 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 137 |
+
let xs = self.conv.forward(xs)?;
|
| 138 |
+
let xs = self.bn.forward_t(&xs, false)?;
|
| 139 |
+
if self.activation {
|
| 140 |
+
candle_nn::ops::silu(&xs)
|
| 141 |
+
} else {
|
| 142 |
+
Ok(xs)
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
#[derive(Debug)]
|
| 148 |
+
struct Bottleneck {
|
| 149 |
+
cv1: ConvBlock,
|
| 150 |
+
cv2: ConvBlock,
|
| 151 |
+
residual: bool,
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
impl Bottleneck {
|
| 155 |
+
fn load(
|
| 156 |
+
vb: VarBuilder,
|
| 157 |
+
in_channels: usize,
|
| 158 |
+
out_channels: usize,
|
| 159 |
+
shortcut: bool,
|
| 160 |
+
groups: usize,
|
| 161 |
+
kernel_size: usize,
|
| 162 |
+
expansion: f64,
|
| 163 |
+
) -> Result<Self> {
|
| 164 |
+
let hidden = (out_channels as f64 * expansion) as usize;
|
| 165 |
+
Ok(Self {
|
| 166 |
+
cv1: ConvBlock::load(
|
| 167 |
+
vb.pp("cv1"),
|
| 168 |
+
in_channels,
|
| 169 |
+
hidden,
|
| 170 |
+
kernel_size,
|
| 171 |
+
1,
|
| 172 |
+
None,
|
| 173 |
+
1,
|
| 174 |
+
true,
|
| 175 |
+
)?,
|
| 176 |
+
cv2: ConvBlock::load(
|
| 177 |
+
vb.pp("cv2"),
|
| 178 |
+
hidden,
|
| 179 |
+
out_channels,
|
| 180 |
+
kernel_size,
|
| 181 |
+
1,
|
| 182 |
+
None,
|
| 183 |
+
groups,
|
| 184 |
+
true,
|
| 185 |
+
)?,
|
| 186 |
+
residual: shortcut && in_channels == out_channels,
|
| 187 |
+
})
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
impl Module for Bottleneck {
|
| 192 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 193 |
+
let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
|
| 194 |
+
if self.residual { xs + ys } else { Ok(ys) }
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
#[derive(Debug)]
|
| 199 |
+
struct C3k {
|
| 200 |
+
cv1: ConvBlock,
|
| 201 |
+
cv2: ConvBlock,
|
| 202 |
+
cv3: ConvBlock,
|
| 203 |
+
blocks: Vec<Bottleneck>,
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
#[derive(Debug, Clone, Copy)]
|
| 207 |
+
struct C3kOptions {
|
| 208 |
+
shortcut: bool,
|
| 209 |
+
groups: usize,
|
| 210 |
+
expansion: f64,
|
| 211 |
+
kernel_size: usize,
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
impl C3k {
|
| 215 |
+
fn load(
|
| 216 |
+
vb: VarBuilder,
|
| 217 |
+
in_channels: usize,
|
| 218 |
+
out_channels: usize,
|
| 219 |
+
repeats: usize,
|
| 220 |
+
options: C3kOptions,
|
| 221 |
+
) -> Result<Self> {
|
| 222 |
+
let hidden = (out_channels as f64 * options.expansion) as usize;
|
| 223 |
+
let mut blocks = Vec::with_capacity(repeats);
|
| 224 |
+
for index in 0..repeats {
|
| 225 |
+
blocks.push(Bottleneck::load(
|
| 226 |
+
vb.pp(format!("m.{index}")),
|
| 227 |
+
hidden,
|
| 228 |
+
hidden,
|
| 229 |
+
options.shortcut,
|
| 230 |
+
options.groups,
|
| 231 |
+
options.kernel_size,
|
| 232 |
+
1.0,
|
| 233 |
+
)?);
|
| 234 |
+
}
|
| 235 |
+
Ok(Self {
|
| 236 |
+
cv1: ConvBlock::load(vb.pp("cv1"), in_channels, hidden, 1, 1, None, 1, true)?,
|
| 237 |
+
cv2: ConvBlock::load(vb.pp("cv2"), in_channels, hidden, 1, 1, None, 1, true)?,
|
| 238 |
+
cv3: ConvBlock::load(vb.pp("cv3"), hidden * 2, out_channels, 1, 1, None, 1, true)?,
|
| 239 |
+
blocks,
|
| 240 |
+
})
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
impl Module for C3k {
|
| 245 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 246 |
+
let mut y1 = self.cv1.forward(xs)?;
|
| 247 |
+
for block in &self.blocks {
|
| 248 |
+
y1 = block.forward(&y1)?;
|
| 249 |
+
}
|
| 250 |
+
let y2 = self.cv2.forward(xs)?;
|
| 251 |
+
self.cv3.forward(&Tensor::cat(&[&y1, &y2], 1)?)
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
#[derive(Debug)]
|
| 256 |
+
enum C3k2Block {
|
| 257 |
+
Bottleneck(Box<Bottleneck>),
|
| 258 |
+
C3k(Box<C3k>),
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
impl Module for C3k2Block {
|
| 262 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 263 |
+
match self {
|
| 264 |
+
Self::Bottleneck(block) => block.forward(xs),
|
| 265 |
+
Self::C3k(block) => block.forward(xs),
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
#[derive(Debug)]
|
| 271 |
+
struct C3k2 {
|
| 272 |
+
cv1: ConvBlock,
|
| 273 |
+
cv2: ConvBlock,
|
| 274 |
+
blocks: Vec<C3k2Block>,
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
#[derive(Debug, Clone, Copy)]
|
| 278 |
+
struct C3k2Options {
|
| 279 |
+
use_c3k: bool,
|
| 280 |
+
expansion: f64,
|
| 281 |
+
groups: usize,
|
| 282 |
+
shortcut: bool,
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
impl C3k2 {
|
| 286 |
+
fn load(
|
| 287 |
+
vb: VarBuilder,
|
| 288 |
+
in_channels: usize,
|
| 289 |
+
out_channels: usize,
|
| 290 |
+
repeats: usize,
|
| 291 |
+
options: C3k2Options,
|
| 292 |
+
) -> Result<Self> {
|
| 293 |
+
let hidden = (out_channels as f64 * options.expansion) as usize;
|
| 294 |
+
let mut blocks = Vec::with_capacity(repeats);
|
| 295 |
+
for index in 0..repeats {
|
| 296 |
+
let vb = vb.pp(format!("m.{index}"));
|
| 297 |
+
let block = if options.use_c3k {
|
| 298 |
+
C3k2Block::C3k(Box::new(C3k::load(
|
| 299 |
+
vb,
|
| 300 |
+
hidden,
|
| 301 |
+
hidden,
|
| 302 |
+
2,
|
| 303 |
+
C3kOptions {
|
| 304 |
+
shortcut: options.shortcut,
|
| 305 |
+
groups: options.groups,
|
| 306 |
+
expansion: 0.5,
|
| 307 |
+
kernel_size: 3,
|
| 308 |
+
},
|
| 309 |
+
)?))
|
| 310 |
+
} else {
|
| 311 |
+
C3k2Block::Bottleneck(Box::new(Bottleneck::load(
|
| 312 |
+
vb,
|
| 313 |
+
hidden,
|
| 314 |
+
hidden,
|
| 315 |
+
options.shortcut,
|
| 316 |
+
options.groups,
|
| 317 |
+
3,
|
| 318 |
+
0.5,
|
| 319 |
+
)?))
|
| 320 |
+
};
|
| 321 |
+
blocks.push(block);
|
| 322 |
+
}
|
| 323 |
+
Ok(Self {
|
| 324 |
+
cv1: ConvBlock::load(vb.pp("cv1"), in_channels, hidden * 2, 1, 1, None, 1, true)?,
|
| 325 |
+
cv2: ConvBlock::load(
|
| 326 |
+
vb.pp("cv2"),
|
| 327 |
+
(2 + repeats) * hidden,
|
| 328 |
+
out_channels,
|
| 329 |
+
1,
|
| 330 |
+
1,
|
| 331 |
+
None,
|
| 332 |
+
1,
|
| 333 |
+
true,
|
| 334 |
+
)?,
|
| 335 |
+
blocks,
|
| 336 |
+
})
|
| 337 |
+
}
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
impl Module for C3k2 {
|
| 341 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 342 |
+
let mut ys = self.cv1.forward(xs)?.chunk(2, 1)?;
|
| 343 |
+
for block in &self.blocks {
|
| 344 |
+
ys.push(block.forward(ys.last().expect("c3k2 chunk"))?);
|
| 345 |
+
}
|
| 346 |
+
let refs = ys.iter().collect::<Vec<_>>();
|
| 347 |
+
self.cv2.forward(&Tensor::cat(&refs, 1)?)
|
| 348 |
+
}
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
#[derive(Debug)]
|
| 352 |
+
struct AreaAttention {
|
| 353 |
+
area: usize,
|
| 354 |
+
num_heads: usize,
|
| 355 |
+
head_dim: usize,
|
| 356 |
+
qkv: ConvBlock,
|
| 357 |
+
proj: ConvBlock,
|
| 358 |
+
pe: ConvBlock,
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
impl AreaAttention {
|
| 362 |
+
fn load(vb: VarBuilder, dim: usize, num_heads: usize, area: usize) -> Result<Self> {
|
| 363 |
+
let head_dim = dim / num_heads;
|
| 364 |
+
let all_head_dim = head_dim * num_heads;
|
| 365 |
+
Ok(Self {
|
| 366 |
+
area,
|
| 367 |
+
num_heads,
|
| 368 |
+
head_dim,
|
| 369 |
+
qkv: ConvBlock::load(vb.pp("qkv"), dim, all_head_dim * 3, 1, 1, None, 1, false)?,
|
| 370 |
+
proj: ConvBlock::load(vb.pp("proj"), all_head_dim, dim, 1, 1, None, 1, false)?,
|
| 371 |
+
pe: ConvBlock::load(vb.pp("pe"), all_head_dim, dim, 7, 1, Some(3), dim, false)?,
|
| 372 |
+
})
|
| 373 |
+
}
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
impl Module for AreaAttention {
|
| 377 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 378 |
+
let (batch, channels, height, width) = xs.dims4()?;
|
| 379 |
+
let num_tokens = height * width;
|
| 380 |
+
let qkv = self
|
| 381 |
+
.qkv
|
| 382 |
+
.forward(xs)?
|
| 383 |
+
.flatten_from(2)?
|
| 384 |
+
.transpose(1, 2)?
|
| 385 |
+
.contiguous()?;
|
| 386 |
+
let qkv = if self.area > 1 {
|
| 387 |
+
qkv.reshape((batch * self.area, num_tokens / self.area, channels * 3))?
|
| 388 |
+
} else {
|
| 389 |
+
qkv
|
| 390 |
+
};
|
| 391 |
+
let (area_batch, area_tokens, _) = qkv.dims3()?;
|
| 392 |
+
let qkv = qkv
|
| 393 |
+
.reshape((area_batch, area_tokens, self.num_heads, self.head_dim * 3))?
|
| 394 |
+
.permute((0, 2, 3, 1))?
|
| 395 |
+
.contiguous()?;
|
| 396 |
+
|
| 397 |
+
let q = qkv.narrow(2, 0, self.head_dim)?;
|
| 398 |
+
let k = qkv.narrow(2, self.head_dim, self.head_dim)?;
|
| 399 |
+
let v = qkv.narrow(2, self.head_dim * 2, self.head_dim)?;
|
| 400 |
+
let attn = (q.transpose(2, 3)?.matmul(&k)? * (self.head_dim as f64).powf(-0.5))?;
|
| 401 |
+
let attn = candle_nn::ops::softmax(&attn, D::Minus1)?;
|
| 402 |
+
let ys = v.matmul(&attn.transpose(2, 3)?)?;
|
| 403 |
+
let ys = ys.permute((0, 3, 1, 2))?.contiguous()?;
|
| 404 |
+
let v = v.permute((0, 3, 1, 2))?.contiguous()?;
|
| 405 |
+
|
| 406 |
+
let (ys, v) = if self.area > 1 {
|
| 407 |
+
(
|
| 408 |
+
ys.reshape((batch, num_tokens, channels))?,
|
| 409 |
+
v.reshape((batch, num_tokens, channels))?,
|
| 410 |
+
)
|
| 411 |
+
} else {
|
| 412 |
+
(ys, v)
|
| 413 |
+
};
|
| 414 |
+
|
| 415 |
+
let ys = ys
|
| 416 |
+
.reshape((batch, height, width, channels))?
|
| 417 |
+
.permute((0, 3, 1, 2))?
|
| 418 |
+
.contiguous()?;
|
| 419 |
+
let v = v
|
| 420 |
+
.reshape((batch, height, width, channels))?
|
| 421 |
+
.permute((0, 3, 1, 2))?
|
| 422 |
+
.contiguous()?;
|
| 423 |
+
self.proj.forward(&(ys + self.pe.forward(&v)?)?)
|
| 424 |
+
}
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
#[derive(Debug)]
|
| 428 |
+
struct AreaBlock {
|
| 429 |
+
attn: AreaAttention,
|
| 430 |
+
mlp0: ConvBlock,
|
| 431 |
+
mlp1: ConvBlock,
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
impl AreaBlock {
|
| 435 |
+
fn load(
|
| 436 |
+
vb: VarBuilder,
|
| 437 |
+
dim: usize,
|
| 438 |
+
num_heads: usize,
|
| 439 |
+
mlp_ratio: f64,
|
| 440 |
+
area: usize,
|
| 441 |
+
) -> Result<Self> {
|
| 442 |
+
let mlp_hidden = (dim as f64 * mlp_ratio) as usize;
|
| 443 |
+
Ok(Self {
|
| 444 |
+
attn: AreaAttention::load(vb.pp("attn"), dim, num_heads, area)?,
|
| 445 |
+
mlp0: ConvBlock::load(vb.pp("mlp.0"), dim, mlp_hidden, 1, 1, None, 1, true)?,
|
| 446 |
+
mlp1: ConvBlock::load(vb.pp("mlp.1"), mlp_hidden, dim, 1, 1, None, 1, false)?,
|
| 447 |
+
})
|
| 448 |
+
}
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
impl Module for AreaBlock {
|
| 452 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 453 |
+
let xs = (xs + self.attn.forward(xs)?)?;
|
| 454 |
+
let mlp = self.mlp1.forward(&self.mlp0.forward(&xs)?)?;
|
| 455 |
+
xs + mlp
|
| 456 |
+
}
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
#[derive(Debug)]
|
| 460 |
+
enum A2C2fBlock {
|
| 461 |
+
Attention(Vec<AreaBlock>),
|
| 462 |
+
C3k(Box<C3k>),
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
impl Module for A2C2fBlock {
|
| 466 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 467 |
+
match self {
|
| 468 |
+
Self::Attention(blocks) => {
|
| 469 |
+
let mut ys = xs.clone();
|
| 470 |
+
for block in blocks {
|
| 471 |
+
ys = block.forward(&ys)?;
|
| 472 |
+
}
|
| 473 |
+
Ok(ys)
|
| 474 |
+
}
|
| 475 |
+
Self::C3k(block) => block.forward(xs),
|
| 476 |
+
}
|
| 477 |
+
}
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
#[derive(Debug)]
|
| 481 |
+
struct A2C2f {
|
| 482 |
+
cv1: ConvBlock,
|
| 483 |
+
cv2: ConvBlock,
|
| 484 |
+
gamma: Option<Tensor>,
|
| 485 |
+
blocks: Vec<A2C2fBlock>,
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
#[allow(clippy::too_many_arguments)]
|
| 489 |
+
impl A2C2f {
|
| 490 |
+
fn load(
|
| 491 |
+
vb: VarBuilder,
|
| 492 |
+
in_channels: usize,
|
| 493 |
+
out_channels: usize,
|
| 494 |
+
repeats: usize,
|
| 495 |
+
attention: bool,
|
| 496 |
+
area: usize,
|
| 497 |
+
residual: bool,
|
| 498 |
+
mlp_ratio: f64,
|
| 499 |
+
expansion: f64,
|
| 500 |
+
groups: usize,
|
| 501 |
+
shortcut: bool,
|
| 502 |
+
) -> Result<Self> {
|
| 503 |
+
let hidden = (out_channels as f64 * expansion) as usize;
|
| 504 |
+
let gamma = if attention && residual {
|
| 505 |
+
Some(vb.get(out_channels, "gamma")?)
|
| 506 |
+
} else {
|
| 507 |
+
None
|
| 508 |
+
};
|
| 509 |
+
let mut blocks = Vec::with_capacity(repeats);
|
| 510 |
+
for index in 0..repeats {
|
| 511 |
+
let block_vb = vb.pp(format!("m.{index}"));
|
| 512 |
+
let block = if attention {
|
| 513 |
+
let mut area_blocks = Vec::with_capacity(2);
|
| 514 |
+
for block_index in 0..2 {
|
| 515 |
+
area_blocks.push(AreaBlock::load(
|
| 516 |
+
block_vb.pp(block_index),
|
| 517 |
+
hidden,
|
| 518 |
+
hidden / 32,
|
| 519 |
+
mlp_ratio,
|
| 520 |
+
area,
|
| 521 |
+
)?);
|
| 522 |
+
}
|
| 523 |
+
A2C2fBlock::Attention(area_blocks)
|
| 524 |
+
} else {
|
| 525 |
+
A2C2fBlock::C3k(Box::new(C3k::load(
|
| 526 |
+
block_vb,
|
| 527 |
+
hidden,
|
| 528 |
+
hidden,
|
| 529 |
+
2,
|
| 530 |
+
C3kOptions {
|
| 531 |
+
shortcut,
|
| 532 |
+
groups,
|
| 533 |
+
expansion: 0.5,
|
| 534 |
+
kernel_size: 3,
|
| 535 |
+
},
|
| 536 |
+
)?))
|
| 537 |
+
};
|
| 538 |
+
blocks.push(block);
|
| 539 |
+
}
|
| 540 |
+
Ok(Self {
|
| 541 |
+
cv1: ConvBlock::load(vb.pp("cv1"), in_channels, hidden, 1, 1, None, 1, true)?,
|
| 542 |
+
cv2: ConvBlock::load(
|
| 543 |
+
vb.pp("cv2"),
|
| 544 |
+
(1 + repeats) * hidden,
|
| 545 |
+
out_channels,
|
| 546 |
+
1,
|
| 547 |
+
1,
|
| 548 |
+
None,
|
| 549 |
+
1,
|
| 550 |
+
true,
|
| 551 |
+
)?,
|
| 552 |
+
gamma,
|
| 553 |
+
blocks,
|
| 554 |
+
})
|
| 555 |
+
}
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
impl Module for A2C2f {
|
| 559 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 560 |
+
let mut ys = vec![self.cv1.forward(xs)?];
|
| 561 |
+
for block in &self.blocks {
|
| 562 |
+
ys.push(block.forward(ys.last().expect("a2c2f output"))?);
|
| 563 |
+
}
|
| 564 |
+
let refs = ys.iter().collect::<Vec<_>>();
|
| 565 |
+
let ys = self.cv2.forward(&Tensor::cat(&refs, 1)?)?;
|
| 566 |
+
match &self.gamma {
|
| 567 |
+
Some(gamma) => {
|
| 568 |
+
xs + ys.broadcast_mul(&gamma.reshape((1, gamma.elem_count(), 1, 1))?)?
|
| 569 |
+
}
|
| 570 |
+
None => Ok(ys),
|
| 571 |
+
}
|
| 572 |
+
}
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
#[derive(Debug)]
|
| 576 |
+
struct Yolo12Backbone {
|
| 577 |
+
l0: ConvBlock,
|
| 578 |
+
l1: ConvBlock,
|
| 579 |
+
l2: C3k2,
|
| 580 |
+
l3: ConvBlock,
|
| 581 |
+
l4: C3k2,
|
| 582 |
+
l5: ConvBlock,
|
| 583 |
+
l6: A2C2f,
|
| 584 |
+
l7: ConvBlock,
|
| 585 |
+
l8: A2C2f,
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
impl Yolo12Backbone {
|
| 589 |
+
fn load(vb: VarBuilder, scale: Yolo12Scale) -> Result<Self> {
|
| 590 |
+
let m = scale.multiples();
|
| 591 |
+
let c64 = m.channels(64);
|
| 592 |
+
let c128 = m.channels(128);
|
| 593 |
+
let c256 = m.channels(256);
|
| 594 |
+
let c512 = m.channels(512);
|
| 595 |
+
let c1024 = m.channels(1024);
|
| 596 |
+
let a2_residual = scale.uses_a2_residual();
|
| 597 |
+
let mlp_ratio = if a2_residual { 1.2 } else { 2.0 };
|
| 598 |
+
|
| 599 |
+
Ok(Self {
|
| 600 |
+
l0: ConvBlock::load(vb.pp("model.0"), 3, c64, 3, 2, None, 1, true)?,
|
| 601 |
+
l1: ConvBlock::load(vb.pp("model.1"), c64, c128, 3, 2, None, 1, true)?,
|
| 602 |
+
l2: C3k2::load(
|
| 603 |
+
vb.pp("model.2"),
|
| 604 |
+
c128,
|
| 605 |
+
c256,
|
| 606 |
+
m.repeats(2),
|
| 607 |
+
C3k2Options {
|
| 608 |
+
use_c3k: scale.uses_large_c3k(),
|
| 609 |
+
expansion: 0.25,
|
| 610 |
+
groups: 1,
|
| 611 |
+
shortcut: true,
|
| 612 |
+
},
|
| 613 |
+
)?,
|
| 614 |
+
l3: ConvBlock::load(vb.pp("model.3"), c256, c256, 3, 2, None, 1, true)?,
|
| 615 |
+
l4: C3k2::load(
|
| 616 |
+
vb.pp("model.4"),
|
| 617 |
+
c256,
|
| 618 |
+
c512,
|
| 619 |
+
m.repeats(2),
|
| 620 |
+
C3k2Options {
|
| 621 |
+
use_c3k: scale.uses_large_c3k(),
|
| 622 |
+
expansion: 0.25,
|
| 623 |
+
groups: 1,
|
| 624 |
+
shortcut: true,
|
| 625 |
+
},
|
| 626 |
+
)?,
|
| 627 |
+
l5: ConvBlock::load(vb.pp("model.5"), c512, c512, 3, 2, None, 1, true)?,
|
| 628 |
+
l6: A2C2f::load(
|
| 629 |
+
vb.pp("model.6"),
|
| 630 |
+
c512,
|
| 631 |
+
c512,
|
| 632 |
+
m.repeats(4),
|
| 633 |
+
true,
|
| 634 |
+
4,
|
| 635 |
+
a2_residual,
|
| 636 |
+
mlp_ratio,
|
| 637 |
+
0.5,
|
| 638 |
+
1,
|
| 639 |
+
true,
|
| 640 |
+
)?,
|
| 641 |
+
l7: ConvBlock::load(vb.pp("model.7"), c512, c1024, 3, 2, None, 1, true)?,
|
| 642 |
+
l8: A2C2f::load(
|
| 643 |
+
vb.pp("model.8"),
|
| 644 |
+
c1024,
|
| 645 |
+
c1024,
|
| 646 |
+
m.repeats(4),
|
| 647 |
+
true,
|
| 648 |
+
1,
|
| 649 |
+
a2_residual,
|
| 650 |
+
mlp_ratio,
|
| 651 |
+
0.5,
|
| 652 |
+
1,
|
| 653 |
+
true,
|
| 654 |
+
)?,
|
| 655 |
+
})
|
| 656 |
+
}
|
| 657 |
+
|
| 658 |
+
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
| 659 |
+
let x0 = self.l0.forward(xs)?;
|
| 660 |
+
let x1 = self.l1.forward(&x0)?;
|
| 661 |
+
let x2 = self.l2.forward(&x1)?;
|
| 662 |
+
let x3 = self.l3.forward(&x2)?;
|
| 663 |
+
let x4 = self.l4.forward(&x3)?;
|
| 664 |
+
let x5 = self.l5.forward(&x4)?;
|
| 665 |
+
let x6 = self.l6.forward(&x5)?;
|
| 666 |
+
let x7 = self.l7.forward(&x6)?;
|
| 667 |
+
let x8 = self.l8.forward(&x7)?;
|
| 668 |
+
Ok((x4, x6, x8))
|
| 669 |
+
}
|
| 670 |
+
}
|
| 671 |
+
|
| 672 |
+
#[derive(Debug)]
|
| 673 |
+
struct Yolo12Neck {
|
| 674 |
+
upsample: Upsample,
|
| 675 |
+
l11: A2C2f,
|
| 676 |
+
l14: A2C2f,
|
| 677 |
+
l15: ConvBlock,
|
| 678 |
+
l17: A2C2f,
|
| 679 |
+
l18: ConvBlock,
|
| 680 |
+
l20: C3k2,
|
| 681 |
+
}
|
| 682 |
+
|
| 683 |
+
impl Yolo12Neck {
|
| 684 |
+
fn load(vb: VarBuilder, scale: Yolo12Scale) -> Result<Self> {
|
| 685 |
+
let m = scale.multiples();
|
| 686 |
+
let c256 = m.channels(256);
|
| 687 |
+
let c512 = m.channels(512);
|
| 688 |
+
let c1024 = m.channels(1024);
|
| 689 |
+
let repeats = m.repeats(2);
|
| 690 |
+
Ok(Self {
|
| 691 |
+
upsample: Upsample::new(2),
|
| 692 |
+
l11: A2C2f::load(
|
| 693 |
+
vb.pp("model.11"),
|
| 694 |
+
c1024 + c512,
|
| 695 |
+
c512,
|
| 696 |
+
repeats,
|
| 697 |
+
false,
|
| 698 |
+
1,
|
| 699 |
+
false,
|
| 700 |
+
2.0,
|
| 701 |
+
0.5,
|
| 702 |
+
1,
|
| 703 |
+
true,
|
| 704 |
+
)?,
|
| 705 |
+
l14: A2C2f::load(
|
| 706 |
+
vb.pp("model.14"),
|
| 707 |
+
c512 + c512,
|
| 708 |
+
c256,
|
| 709 |
+
repeats,
|
| 710 |
+
false,
|
| 711 |
+
1,
|
| 712 |
+
false,
|
| 713 |
+
2.0,
|
| 714 |
+
0.5,
|
| 715 |
+
1,
|
| 716 |
+
true,
|
| 717 |
+
)?,
|
| 718 |
+
l15: ConvBlock::load(vb.pp("model.15"), c256, c256, 3, 2, None, 1, true)?,
|
| 719 |
+
l17: A2C2f::load(
|
| 720 |
+
vb.pp("model.17"),
|
| 721 |
+
c256 + c512,
|
| 722 |
+
c512,
|
| 723 |
+
repeats,
|
| 724 |
+
false,
|
| 725 |
+
1,
|
| 726 |
+
false,
|
| 727 |
+
2.0,
|
| 728 |
+
0.5,
|
| 729 |
+
1,
|
| 730 |
+
true,
|
| 731 |
+
)?,
|
| 732 |
+
l18: ConvBlock::load(vb.pp("model.18"), c512, c512, 3, 2, None, 1, true)?,
|
| 733 |
+
l20: C3k2::load(
|
| 734 |
+
vb.pp("model.20"),
|
| 735 |
+
c512 + c1024,
|
| 736 |
+
c1024,
|
| 737 |
+
repeats,
|
| 738 |
+
C3k2Options {
|
| 739 |
+
use_c3k: true,
|
| 740 |
+
expansion: 0.5,
|
| 741 |
+
groups: 1,
|
| 742 |
+
shortcut: true,
|
| 743 |
+
},
|
| 744 |
+
)?,
|
| 745 |
+
})
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
| 749 |
+
let x11 = self
|
| 750 |
+
.l11
|
| 751 |
+
.forward(&Tensor::cat(&[&self.upsample.forward(p5)?, p4], 1)?)?;
|
| 752 |
+
let x14 = self
|
| 753 |
+
.l14
|
| 754 |
+
.forward(&Tensor::cat(&[&self.upsample.forward(&x11)?, p3], 1)?)?;
|
| 755 |
+
let x17 = self
|
| 756 |
+
.l17
|
| 757 |
+
.forward(&Tensor::cat(&[&self.l15.forward(&x14)?, &x11], 1)?)?;
|
| 758 |
+
let x20 = self
|
| 759 |
+
.l20
|
| 760 |
+
.forward(&Tensor::cat(&[&self.l18.forward(&x17)?, p5], 1)?)?;
|
| 761 |
+
Ok((x14, x17, x20))
|
| 762 |
+
}
|
| 763 |
+
}
|
| 764 |
+
|
| 765 |
+
#[derive(Debug)]
|
| 766 |
+
struct Dfl {
|
| 767 |
+
conv: Conv2d,
|
| 768 |
+
reg_max: usize,
|
| 769 |
+
}
|
| 770 |
+
|
| 771 |
+
impl Dfl {
|
| 772 |
+
fn load(vb: VarBuilder, reg_max: usize) -> Result<Self> {
|
| 773 |
+
Ok(Self {
|
| 774 |
+
conv: conv2d_no_bias(reg_max, 1, 1, Default::default(), vb.pp("conv"))?,
|
| 775 |
+
reg_max,
|
| 776 |
+
})
|
| 777 |
+
}
|
| 778 |
+
}
|
| 779 |
+
|
| 780 |
+
impl Module for Dfl {
|
| 781 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 782 |
+
let (batch, _, anchors) = xs.dims3()?;
|
| 783 |
+
let xs = xs
|
| 784 |
+
.reshape((batch, 4, self.reg_max, anchors))?
|
| 785 |
+
.transpose(2, 1)?;
|
| 786 |
+
let xs = candle_nn::ops::softmax(&xs, 1)?;
|
| 787 |
+
self.conv.forward(&xs)?.reshape((batch, 4, anchors))
|
| 788 |
+
}
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
#[derive(Debug)]
|
| 792 |
+
struct DetectCv3 {
|
| 793 |
+
dw0: ConvBlock,
|
| 794 |
+
pw0: ConvBlock,
|
| 795 |
+
dw1: ConvBlock,
|
| 796 |
+
pw1: ConvBlock,
|
| 797 |
+
conv: Conv2d,
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
impl DetectCv3 {
|
| 801 |
+
fn load(vb: VarBuilder, in_channels: usize, hidden: usize, num_classes: usize) -> Result<Self> {
|
| 802 |
+
Ok(Self {
|
| 803 |
+
dw0: ConvBlock::load(
|
| 804 |
+
vb.pp("0.0"),
|
| 805 |
+
in_channels,
|
| 806 |
+
in_channels,
|
| 807 |
+
3,
|
| 808 |
+
1,
|
| 809 |
+
None,
|
| 810 |
+
in_channels,
|
| 811 |
+
true,
|
| 812 |
+
)?,
|
| 813 |
+
pw0: ConvBlock::load(vb.pp("0.1"), in_channels, hidden, 1, 1, None, 1, true)?,
|
| 814 |
+
dw1: ConvBlock::load(vb.pp("1.0"), hidden, hidden, 3, 1, None, hidden, true)?,
|
| 815 |
+
pw1: ConvBlock::load(vb.pp("1.1"), hidden, hidden, 1, 1, None, 1, true)?,
|
| 816 |
+
conv: conv2d(hidden, num_classes, 1, Default::default(), vb.pp("2"))?,
|
| 817 |
+
})
|
| 818 |
+
}
|
| 819 |
+
}
|
| 820 |
+
|
| 821 |
+
impl Module for DetectCv3 {
|
| 822 |
+
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 823 |
+
let xs = self.pw0.forward(&self.dw0.forward(xs)?)?;
|
| 824 |
+
let xs = self.pw1.forward(&self.dw1.forward(&xs)?)?;
|
| 825 |
+
self.conv.forward(&xs)
|
| 826 |
+
}
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
#[derive(Debug)]
|
| 830 |
+
struct DetectionHead {
|
| 831 |
+
dfl: Dfl,
|
| 832 |
+
cv2: [(ConvBlock, ConvBlock, Conv2d); 3],
|
| 833 |
+
cv3: [DetectCv3; 3],
|
| 834 |
+
reg_max: usize,
|
| 835 |
+
no: usize,
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
impl DetectionHead {
|
| 839 |
+
fn load(vb: VarBuilder, num_classes: usize, filters: (usize, usize, usize)) -> Result<Self> {
|
| 840 |
+
let c2 = filters.0.div_ceil(4).max(REG_MAX * 4).max(16);
|
| 841 |
+
let c3 = filters.0.max(num_classes.min(100));
|
| 842 |
+
Ok(Self {
|
| 843 |
+
dfl: Dfl::load(vb.pp("dfl"), REG_MAX)?,
|
| 844 |
+
cv2: [
|
| 845 |
+
Self::load_cv2(vb.pp("cv2.0"), filters.0, c2)?,
|
| 846 |
+
Self::load_cv2(vb.pp("cv2.1"), filters.1, c2)?,
|
| 847 |
+
Self::load_cv2(vb.pp("cv2.2"), filters.2, c2)?,
|
| 848 |
+
],
|
| 849 |
+
cv3: [
|
| 850 |
+
DetectCv3::load(vb.pp("cv3.0"), filters.0, c3, num_classes)?,
|
| 851 |
+
DetectCv3::load(vb.pp("cv3.1"), filters.1, c3, num_classes)?,
|
| 852 |
+
DetectCv3::load(vb.pp("cv3.2"), filters.2, c3, num_classes)?,
|
| 853 |
+
],
|
| 854 |
+
reg_max: REG_MAX,
|
| 855 |
+
no: num_classes + REG_MAX * 4,
|
| 856 |
+
})
|
| 857 |
+
}
|
| 858 |
+
|
| 859 |
+
fn load_cv2(
|
| 860 |
+
vb: VarBuilder,
|
| 861 |
+
in_channels: usize,
|
| 862 |
+
hidden: usize,
|
| 863 |
+
) -> Result<(ConvBlock, ConvBlock, Conv2d)> {
|
| 864 |
+
Ok((
|
| 865 |
+
ConvBlock::load(vb.pp("0"), in_channels, hidden, 3, 1, None, 1, true)?,
|
| 866 |
+
ConvBlock::load(vb.pp("1"), hidden, hidden, 3, 1, None, 1, true)?,
|
| 867 |
+
conv2d(hidden, REG_MAX * 4, 1, Default::default(), vb.pp("2"))?,
|
| 868 |
+
))
|
| 869 |
+
}
|
| 870 |
+
|
| 871 |
+
fn forward_cv2(block: &(ConvBlock, ConvBlock, Conv2d), xs: &Tensor) -> Result<Tensor> {
|
| 872 |
+
block.2.forward(&block.1.forward(&block.0.forward(xs)?)?)
|
| 873 |
+
}
|
| 874 |
+
|
| 875 |
+
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
| 876 |
+
let xs0 = Tensor::cat(
|
| 877 |
+
&[
|
| 878 |
+
&Self::forward_cv2(&self.cv2[0], xs0)?,
|
| 879 |
+
&self.cv3[0].forward(xs0)?,
|
| 880 |
+
],
|
| 881 |
+
1,
|
| 882 |
+
)?;
|
| 883 |
+
let xs1 = Tensor::cat(
|
| 884 |
+
&[
|
| 885 |
+
&Self::forward_cv2(&self.cv2[1], xs1)?,
|
| 886 |
+
&self.cv3[1].forward(xs1)?,
|
| 887 |
+
],
|
| 888 |
+
1,
|
| 889 |
+
)?;
|
| 890 |
+
let xs2 = Tensor::cat(
|
| 891 |
+
&[
|
| 892 |
+
&Self::forward_cv2(&self.cv2[2], xs2)?,
|
| 893 |
+
&self.cv3[2].forward(xs2)?,
|
| 894 |
+
],
|
| 895 |
+
1,
|
| 896 |
+
)?;
|
| 897 |
+
|
| 898 |
+
let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?;
|
| 899 |
+
let anchors = anchors.transpose(0, 1)?.unsqueeze(0)?;
|
| 900 |
+
let strides = strides.transpose(0, 1)?;
|
| 901 |
+
|
| 902 |
+
let reshape = |xs: &Tensor| {
|
| 903 |
+
let batch = xs.dim(0)?;
|
| 904 |
+
xs.reshape((batch, self.no, xs.elem_count() / (batch * self.no)))
|
| 905 |
+
};
|
| 906 |
+
let ys0 = reshape(&xs0)?;
|
| 907 |
+
let ys1 = reshape(&xs1)?;
|
| 908 |
+
let ys2 = reshape(&xs2)?;
|
| 909 |
+
let x_cat = Tensor::cat(&[&ys0, &ys1, &ys2], 2)?;
|
| 910 |
+
let box_ = x_cat.i((.., ..self.reg_max * 4, ..))?;
|
| 911 |
+
let cls = x_cat.i((.., self.reg_max * 4.., ..))?;
|
| 912 |
+
let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors)?.broadcast_mul(&strides)?;
|
| 913 |
+
Tensor::cat(&[&dbox, &candle_nn::ops::sigmoid(&cls)?], 1)
|
| 914 |
+
}
|
| 915 |
+
}
|
| 916 |
+
|
| 917 |
+
fn make_anchors(
|
| 918 |
+
xs0: &Tensor,
|
| 919 |
+
xs1: &Tensor,
|
| 920 |
+
xs2: &Tensor,
|
| 921 |
+
strides: (usize, usize, usize),
|
| 922 |
+
grid_cell_offset: f64,
|
| 923 |
+
) -> Result<(Tensor, Tensor)> {
|
| 924 |
+
let device = xs0.device();
|
| 925 |
+
let dtype = xs0.dtype();
|
| 926 |
+
let mut anchor_points = Vec::with_capacity(3);
|
| 927 |
+
let mut stride_tensors = Vec::with_capacity(3);
|
| 928 |
+
for (xs, stride) in [(xs0, strides.0), (xs1, strides.1), (xs2, strides.2)] {
|
| 929 |
+
let (_, _, h, w) = xs.dims4()?;
|
| 930 |
+
let sx = (Tensor::arange(0, w as u32, device)?.to_dtype(dtype)? + grid_cell_offset)?;
|
| 931 |
+
let sy = (Tensor::arange(0, h as u32, device)?.to_dtype(dtype)? + grid_cell_offset)?;
|
| 932 |
+
let sx = sx
|
| 933 |
+
.reshape((1, sx.elem_count()))?
|
| 934 |
+
.repeat((h, 1))?
|
| 935 |
+
.flatten_all()?;
|
| 936 |
+
let sy = sy
|
| 937 |
+
.reshape((sy.elem_count(), 1))?
|
| 938 |
+
.repeat((1, w))?
|
| 939 |
+
.flatten_all()?;
|
| 940 |
+
anchor_points.push(Tensor::stack(&[&sx, &sy], D::Minus1)?);
|
| 941 |
+
stride_tensors.push((Tensor::ones(h * w, dtype, device)? * stride as f64)?);
|
| 942 |
+
}
|
| 943 |
+
let anchor_points = Tensor::cat(anchor_points.as_slice(), 0)?;
|
| 944 |
+
let stride_tensor = Tensor::cat(stride_tensors.as_slice(), 0)?.unsqueeze(1)?;
|
| 945 |
+
Ok((anchor_points, stride_tensor))
|
| 946 |
+
}
|
| 947 |
+
|
| 948 |
+
fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result<Tensor> {
|
| 949 |
+
let chunks = distance.chunk(2, 1)?;
|
| 950 |
+
let lt = &chunks[0];
|
| 951 |
+
let rb = &chunks[1];
|
| 952 |
+
let x1y1 = anchor_points.sub(lt)?;
|
| 953 |
+
let x2y2 = anchor_points.add(rb)?;
|
| 954 |
+
let c_xy = ((&x1y1 + &x2y2)? * 0.5)?;
|
| 955 |
+
let wh = (&x2y2 - &x1y1)?;
|
| 956 |
+
Tensor::cat(&[&c_xy, &wh], 1)
|
| 957 |
+
}
|
| 958 |
+
|
| 959 |
+
#[derive(Debug)]
|
| 960 |
+
pub struct Yolo12 {
|
| 961 |
+
backbone: Yolo12Backbone,
|
| 962 |
+
neck: Yolo12Neck,
|
| 963 |
+
head: DetectionHead,
|
| 964 |
+
}
|
| 965 |
+
|
| 966 |
+
impl Yolo12 {
|
| 967 |
+
pub fn load(vb: VarBuilder, scale: Yolo12Scale, num_classes: usize) -> Result<Self> {
|
| 968 |
+
let m = scale.multiples();
|
| 969 |
+
let filters = (m.channels(256), m.channels(512), m.channels(1024));
|
| 970 |
+
Ok(Self {
|
| 971 |
+
backbone: Yolo12Backbone::load(vb.clone(), scale)?,
|
| 972 |
+
neck: Yolo12Neck::load(vb.clone(), scale)?,
|
| 973 |
+
head: DetectionHead::load(vb.pp("model.21"), num_classes, filters)?,
|
| 974 |
+
})
|
| 975 |
+
}
|
| 976 |
+
|
| 977 |
+
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
| 978 |
+
let (p3, p4, p5) = self.backbone.forward(xs)?;
|
| 979 |
+
let (h1, h2, h3) = self.neck.forward(&p3, &p4, &p5)?;
|
| 980 |
+
self.head.forward(&h1, &h2, &h3)
|
| 981 |
+
}
|
| 982 |
+
}
|
koharu-ml/src/lib.rs
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
mod hf_hub;
|
| 2 |
|
|
|
|
| 3 |
pub mod aot_inpainting;
|
| 4 |
pub mod comic_text_bubble_detector;
|
| 5 |
pub mod comic_text_detector;
|
|
|
|
| 1 |
mod hf_hub;
|
| 2 |
|
| 3 |
+
pub mod anime_text;
|
| 4 |
pub mod aot_inpainting;
|
| 5 |
pub mod comic_text_bubble_detector;
|
| 6 |
pub mod comic_text_detector;
|
koharu-ml/tests/anime_text.rs
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use std::path::Path;
|
| 2 |
+
|
| 3 |
+
use koharu_ml::anime_text::AnimeTextDetector;
|
| 4 |
+
|
| 5 |
+
mod support;
|
| 6 |
+
|
| 7 |
+
#[tokio::test]
|
| 8 |
+
#[ignore = "requires model download and is not critical for CI"]
|
| 9 |
+
async fn anime_text_yolo() -> anyhow::Result<()> {
|
| 10 |
+
let runtime = support::cpu_runtime();
|
| 11 |
+
let model = AnimeTextDetector::load(&runtime, false).await?;
|
| 12 |
+
|
| 13 |
+
let image = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/1.jpg"))?;
|
| 14 |
+
let detection = model.inference(&image)?;
|
| 15 |
+
|
| 16 |
+
assert_eq!(detection.image_width, image.width());
|
| 17 |
+
assert_eq!(detection.image_height, image.height());
|
| 18 |
+
assert!(
|
| 19 |
+
!detection.text_blocks.is_empty(),
|
| 20 |
+
"expected anime text YOLO to detect text blocks"
|
| 21 |
+
);
|
| 22 |
+
assert!(
|
| 23 |
+
detection
|
| 24 |
+
.text_blocks
|
| 25 |
+
.iter()
|
| 26 |
+
.all(|block| block.detector.as_deref() == Some("anime-text-yolo"))
|
| 27 |
+
);
|
| 28 |
+
|
| 29 |
+
Ok(())
|
| 30 |
+
}
|