Mayo committed on
Commit
21d917f
·
unverified ·
1 Parent(s): f9d9a50

feat: add mit48px OCR backend

Browse files
koharu-ml/Cargo.toml CHANGED
@@ -78,6 +78,10 @@ path = "bin/llm.rs"
78
  name = "manga-ocr"
79
  path = "bin/manga-ocr.rs"
80
 
 
 
 
 
81
  [[bin]]
82
  name = "font-detect"
83
  path = "bin/font-detect.rs"
 
78
  name = "manga-ocr"
79
  path = "bin/manga-ocr.rs"
80
 
81
+ [[bin]]
82
+ name = "mit48px-ocr"
83
+ path = "bin/mit48px-ocr.rs"
84
+
85
  [[bin]]
86
  name = "font-detect"
87
  path = "bin/font-detect.rs"
koharu-ml/README.md CHANGED
@@ -5,7 +5,8 @@ Model wrappers and CLI tools for the Koharu app.
5
  ## Modules
6
 
7
  - `comic_text_detector`: ONNX model that finds speech bubbles/text blocks and returns bounding boxes plus a segmentation mask.
8
- - `manga_ocr`: encoder/decoder OCR pipeline that reads cropped text regions.
 
9
  - `lama`: LaMa inpainting with tiled blending to remove text using a mask.
10
  - `llm`: quantized GGUF loader (Llama or Qwen2) using candle with chat-style prompting and generation controls.
11
  - `font_detect`: Candle ResNet50 that reproduces YuzuMarker.FontDetection (CJK font/style classifier).
@@ -14,6 +15,7 @@ Model wrappers and CLI tools for the Koharu app.
14
 
15
  ```bash
16
  cargo run -p koharu-models --bin comic-text-detector -- --input page.png --output boxes.png
 
17
  cargo run -p koharu-models --bin manga-ocr -- --input bubble.png
18
  cargo run -p koharu-models --bin lama -- --input page.png --mask mask.png --output filled.png
19
  cargo run -p koharu-models --bin llm -- --prompt "konnichiwa" --model vntl-llama3-8b-v2
 
5
  ## Modules
6
 
7
  - `comic_text_detector`: ONNX model that finds speech bubbles/text blocks and returns bounding boxes plus a segmentation mask.
8
+ - `mit48px_ocr`: autoregressive OCR pipeline that reads per-line text regions and is the default document OCR backend.
9
+ - `manga_ocr`: legacy encoder/decoder OCR pipeline that reads cropped text regions.
10
  - `lama`: LaMa inpainting with tiled blending to remove text using a mask.
11
  - `llm`: quantized GGUF loader (Llama or Qwen2) using candle with chat-style prompting and generation controls.
12
  - `font_detect`: Candle ResNet50 that reproduces YuzuMarker.FontDetection (CJK font/style classifier).
 
15
 
16
  ```bash
17
  cargo run -p koharu-models --bin comic-text-detector -- --input page.png --output boxes.png
18
+ cargo run -p koharu-models --bin mit48px-ocr -- --input bubble.png
19
  cargo run -p koharu-models --bin manga-ocr -- --input bubble.png
20
  cargo run -p koharu-models --bin lama -- --input page.png --mask mask.png --output filled.png
21
  cargo run -p koharu-models --bin llm -- --prompt "konnichiwa" --model vntl-llama3-8b-v2
koharu-ml/bin/mit48px-ocr.rs ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use clap::Parser;
2
+ use koharu_ml::mit48px_ocr::{Mit48pxBlockPrediction, Mit48pxOcr, Mit48pxPrediction};
3
+ use koharu_types::TextBlock;
4
+
5
+ #[path = "common.rs"]
6
+ mod common;
7
+
8
+ #[derive(Parser)]
9
+ struct Cli {
10
+ #[arg(long, value_name = "FILE")]
11
+ input: String,
12
+
13
+ #[arg(long, value_name = "DIR")]
14
+ model_dir: Option<String>,
15
+
16
+ #[arg(long, value_name = "FILE")]
17
+ blocks_json: Option<String>,
18
+
19
+ #[arg(long, value_name = "FILE")]
20
+ json_output: Option<String>,
21
+
22
+ #[arg(long, default_value_t = false)]
23
+ cpu: bool,
24
+ }
25
+
26
+ #[derive(serde::Serialize)]
27
+ #[serde(rename_all = "camelCase")]
28
+ struct OutputEnvelope {
29
+ regions: Option<Vec<Mit48pxPrediction>>,
30
+ blocks: Option<Vec<Mit48pxBlockPrediction>>,
31
+ }
32
+
33
+ #[tokio::main]
34
+ async fn main() -> anyhow::Result<()> {
35
+ common::init_tracing();
36
+
37
+ let cli = Cli::parse();
38
+ let image = image::open(&cli.input)?;
39
+ let model = if let Some(model_dir) = &cli.model_dir {
40
+ Mit48pxOcr::load_from_dir(model_dir, cli.cpu)?
41
+ } else {
42
+ Mit48pxOcr::load(cli.cpu).await?
43
+ };
44
+
45
+ let output = if let Some(blocks_path) = &cli.blocks_json {
46
+ let blocks: Vec<TextBlock> = serde_json::from_str(&std::fs::read_to_string(blocks_path)?)?;
47
+ let predictions = model.inference_text_blocks(&image, &blocks)?;
48
+ for prediction in &predictions {
49
+ println!(
50
+ "#{} {:.4} {}",
51
+ prediction.block_index, prediction.confidence, prediction.text
52
+ );
53
+ }
54
+ OutputEnvelope {
55
+ regions: None,
56
+ blocks: Some(predictions),
57
+ }
58
+ } else {
59
+ let predictions = model.inference_regions(&[image])?;
60
+ for prediction in &predictions {
61
+ println!("{:.4} {}", prediction.confidence, prediction.text);
62
+ }
63
+ OutputEnvelope {
64
+ regions: Some(predictions),
65
+ blocks: None,
66
+ }
67
+ };
68
+
69
+ if let Some(path) = &cli.json_output {
70
+ std::fs::write(path, serde_json::to_string_pretty(&output)?)?;
71
+ }
72
+
73
+ Ok(())
74
+ }
koharu-ml/src/facade.rs CHANGED
@@ -5,7 +5,7 @@ use koharu_types::{Document, FontPrediction, SerializableDynamicImage};
5
  use crate::comic_text_detector::{self, ComicTextDetector};
6
  use crate::font_detector::{self, FontDetector};
7
  use crate::lama::{self, Lama};
8
- use crate::manga_ocr::{self, MangaOcr};
9
 
10
  const NEAR_BLACK_THRESHOLD: u8 = 12;
11
  const GRAY_NEAR_BLACK_THRESHOLD: u8 = 60;
@@ -69,7 +69,7 @@ fn normalize_font_prediction(prediction: &mut FontPrediction) {
69
 
70
  pub struct Model {
71
  dialog_detector: ComicTextDetector,
72
- ocr: MangaOcr,
73
  lama: Lama,
74
  font_detector: FontDetector,
75
  }
@@ -78,7 +78,7 @@ impl Model {
78
  pub async fn new(use_cpu: bool) -> Result<Self> {
79
  Ok(Self {
80
  dialog_detector: ComicTextDetector::load(use_cpu).await?,
81
- ocr: MangaOcr::load(use_cpu).await?,
82
  lama: Lama::load(use_cpu).await?,
83
  font_detector: FontDetector::load(use_cpu).await?,
84
  })
@@ -122,22 +122,14 @@ impl Model {
122
  return Ok(());
123
  }
124
 
125
- let crops: Vec<DynamicImage> = doc
126
- .text_blocks
127
- .iter()
128
- .map(|block| {
129
- doc.image.crop_imm(
130
- block.x as u32,
131
- block.y as u32,
132
- block.width as u32,
133
- block.height as u32,
134
- )
135
- })
136
- .collect();
137
- let texts = self.ocr.inference(&crops)?;
138
-
139
- for (block, text) in doc.text_blocks.iter_mut().zip(texts) {
140
- block.text = text.into();
141
  }
142
 
143
  Ok(())
@@ -192,7 +184,7 @@ impl Model {
192
 
193
  pub async fn prefetch() -> Result<()> {
194
  comic_text_detector::prefetch().await?;
195
- manga_ocr::prefetch().await?;
196
  lama::prefetch().await?;
197
  font_detector::prefetch().await?;
198
 
 
5
  use crate::comic_text_detector::{self, ComicTextDetector};
6
  use crate::font_detector::{self, FontDetector};
7
  use crate::lama::{self, Lama};
8
+ use crate::mit48px_ocr::{self, Mit48pxOcr};
9
 
10
  const NEAR_BLACK_THRESHOLD: u8 = 12;
11
  const GRAY_NEAR_BLACK_THRESHOLD: u8 = 60;
 
69
 
70
  pub struct Model {
71
  dialog_detector: ComicTextDetector,
72
+ ocr: Mit48pxOcr,
73
  lama: Lama,
74
  font_detector: FontDetector,
75
  }
 
78
  pub async fn new(use_cpu: bool) -> Result<Self> {
79
  Ok(Self {
80
  dialog_detector: ComicTextDetector::load(use_cpu).await?,
81
+ ocr: Mit48pxOcr::load(use_cpu).await?,
82
  lama: Lama::load(use_cpu).await?,
83
  font_detector: FontDetector::load(use_cpu).await?,
84
  })
 
122
  return Ok(());
123
  }
124
 
125
+ let predictions = self
126
+ .ocr
127
+ .inference_text_blocks(&doc.image, &doc.text_blocks)?;
128
+
129
+ for prediction in predictions {
130
+ if let Some(block) = doc.text_blocks.get_mut(prediction.block_index) {
131
+ block.text = Some(prediction.text);
132
+ }
 
 
 
 
 
 
 
 
133
  }
134
 
135
  Ok(())
 
184
 
185
  pub async fn prefetch() -> Result<()> {
186
  comic_text_detector::prefetch().await?;
187
+ mit48px_ocr::prefetch().await?;
188
  lama::prefetch().await?;
189
  font_detector::prefetch().await?;
190
 
koharu-ml/src/lib.rs CHANGED
@@ -7,6 +7,7 @@ pub mod lama;
7
  pub mod llm;
8
  pub mod loading;
9
  pub mod manga_ocr;
 
10
 
11
  use anyhow::Result;
12
  use candle_core::utils::metal_is_available;
 
7
  pub mod llm;
8
  pub mod loading;
9
  pub mod manga_ocr;
10
+ pub mod mit48px_ocr;
11
 
12
  use anyhow::Result;
13
  use candle_core::utils::metal_is_available;
koharu-ml/src/mit48px_ocr/mod.rs ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mod model;
2
+
3
+ use std::path::{Path, PathBuf};
4
+
5
+ use anyhow::{Context, Result};
6
+ use candle_core::{DType, Device, Tensor};
7
+ use candle_nn::VarBuilder;
8
+ use image::{DynamicImage, RgbImage, imageops::FilterType};
9
+ use koharu_types::TextBlock;
10
+ use serde::{Deserialize, Serialize};
11
+ use tracing::instrument;
12
+
13
+ use model::{Mit48pxModel, RawPrediction};
14
+
15
+ use crate::{comic_text_detector::extract_text_block_regions, define_models, device, loading};
16
+
17
+ const OCR_CHUNK_SIZE: usize = 16;
18
+
19
+ define_models! {
20
+ Config => ("mayocream/mit48px-ocr", "config.json"),
21
+ Dictionary => ("mayocream/mit48px-ocr", "alphabet-all-v7.txt"),
22
+ Model => ("mayocream/mit48px-ocr", "model.safetensors"),
23
+ }
24
+
25
+ #[derive(Debug, Clone, Serialize, Deserialize)]
26
+ pub struct Mit48pxConfig {
27
+ pub text_height: u32,
28
+ pub max_width: u32,
29
+ pub embd_dim: usize,
30
+ pub num_heads: usize,
31
+ pub encoder_layers: usize,
32
+ pub decoder_layers: usize,
33
+ pub beam_size_default: usize,
34
+ pub max_seq_length_default: usize,
35
+ pub pad_token_id: u32,
36
+ pub bos_token_id: u32,
37
+ pub eos_token_id: u32,
38
+ pub space_token: String,
39
+ pub dictionary_file: String,
40
+ }
41
+
42
+ #[derive(Debug, Clone, Serialize, Deserialize)]
43
+ #[serde(rename_all = "camelCase")]
44
+ pub struct Mit48pxPrediction {
45
+ pub text: String,
46
+ pub confidence: f32,
47
+ pub text_color: [u8; 3],
48
+ pub stroke_color: [u8; 3],
49
+ pub has_text_color: bool,
50
+ pub has_stroke_color: bool,
51
+ }
52
+
53
+ #[derive(Debug, Clone, Serialize, Deserialize)]
54
+ #[serde(rename_all = "camelCase")]
55
+ pub struct Mit48pxBlockPrediction {
56
+ pub block_index: usize,
57
+ pub text: String,
58
+ pub confidence: f32,
59
+ pub text_color: [u8; 3],
60
+ pub stroke_color: [u8; 3],
61
+ }
62
+
63
+ struct PreparedBatch {
64
+ tensor: Tensor,
65
+ widths: Vec<u32>,
66
+ }
67
+
68
/// On-disk locations of the three assets needed to build the model.
struct ModelFiles {
    /// JSON file parsed into `Mit48pxConfig`.
    config: PathBuf,
    /// Text file with one dictionary entry per line.
    dictionary: PathBuf,
    /// Safetensors file holding the weights.
    weights: PathBuf,
}
73
+
74
+ pub struct Mit48pxOcr {
75
+ model: Mit48pxModel,
76
+ config: Mit48pxConfig,
77
+ dictionary: Vec<String>,
78
+ device: Device,
79
+ }
80
+
81
+ impl Mit48pxOcr {
82
+ pub async fn load(use_cpu: bool) -> Result<Self> {
83
+ let files = ModelFiles {
84
+ config: loading::resolve_manifest_path(Manifest::Config.get()).await?,
85
+ dictionary: loading::resolve_manifest_path(Manifest::Dictionary.get()).await?,
86
+ weights: loading::resolve_manifest_path(Manifest::Model.get()).await?,
87
+ };
88
+ Self::load_from_files(files, use_cpu)
89
+ }
90
+
91
+ pub fn load_from_dir(dir: impl AsRef<Path>, use_cpu: bool) -> Result<Self> {
92
+ let dir = dir.as_ref();
93
+ Self::load_from_files(
94
+ ModelFiles {
95
+ config: dir.join("config.json"),
96
+ dictionary: dir.join("alphabet-all-v7.txt"),
97
+ weights: dir.join("model.safetensors"),
98
+ },
99
+ use_cpu,
100
+ )
101
+ }
102
+
103
+ fn load_from_files(files: ModelFiles, use_cpu: bool) -> Result<Self> {
104
+ let device = device(use_cpu)?;
105
+ let config: Mit48pxConfig =
106
+ loading::read_json(&files.config).context("failed to parse mit48px config")?;
107
+ let dictionary = read_dictionary(&files.dictionary)?;
108
+ let data = std::fs::read(&files.weights)
109
+ .with_context(|| format!("failed to read {}", files.weights.display()))?;
110
+ let vb = VarBuilder::from_buffered_safetensors(data, DType::F32, &device)?;
111
+ let model = Mit48pxModel::new(config.clone(), dictionary.len(), vb, device.clone())?;
112
+ Ok(Self {
113
+ model,
114
+ config,
115
+ dictionary,
116
+ device,
117
+ })
118
+ }
119
+
120
+ #[instrument(level = "debug", skip_all)]
121
+ pub fn inference_regions(&self, regions: &[DynamicImage]) -> Result<Vec<Mit48pxPrediction>> {
122
+ if regions.is_empty() {
123
+ return Ok(Vec::new());
124
+ }
125
+
126
+ let mut predictions = Vec::with_capacity(regions.len());
127
+ for chunk in regions.chunks(OCR_CHUNK_SIZE) {
128
+ let batch = preprocess_regions(chunk, &self.config, &self.device)?;
129
+ let raw = self.model.infer_batch(&batch.tensor, &batch.widths)?;
130
+ for prediction in raw {
131
+ predictions.push(self.decode_prediction(prediction));
132
+ }
133
+ }
134
+ Ok(predictions)
135
+ }
136
+
137
+ #[instrument(level = "debug", skip_all)]
138
+ pub fn inference_text_blocks(
139
+ &self,
140
+ image: &DynamicImage,
141
+ blocks: &[TextBlock],
142
+ ) -> Result<Vec<Mit48pxBlockPrediction>> {
143
+ let mut regions = Vec::new();
144
+ let mut block_indices = Vec::new();
145
+ for (block_index, block) in blocks.iter().enumerate() {
146
+ for region in extract_text_block_regions(image, block) {
147
+ regions.push(region);
148
+ block_indices.push(block_index);
149
+ }
150
+ }
151
+
152
+ let line_predictions = self.inference_regions(&regions)?;
153
+ let mut grouped = vec![Vec::<Mit48pxPrediction>::new(); blocks.len()];
154
+ for (prediction, block_index) in line_predictions.into_iter().zip(block_indices) {
155
+ grouped[block_index].push(prediction);
156
+ }
157
+
158
+ let mut outputs = Vec::with_capacity(blocks.len());
159
+ for (block_index, lines) in grouped.into_iter().enumerate() {
160
+ if lines.is_empty() {
161
+ outputs.push(Mit48pxBlockPrediction {
162
+ block_index,
163
+ text: String::new(),
164
+ confidence: 0.0,
165
+ text_color: [0, 0, 0],
166
+ stroke_color: [0, 0, 0],
167
+ });
168
+ continue;
169
+ }
170
+
171
+ let text = lines
172
+ .iter()
173
+ .map(|line| line.text.as_str())
174
+ .collect::<Vec<_>>()
175
+ .join("\n");
176
+ let confidence =
177
+ lines.iter().map(|line| line.confidence).sum::<f32>() / lines.len() as f32;
178
+ let text_color = average_rgb(lines.iter().map(|line| line.text_color));
179
+ let stroke_color = average_rgb(lines.iter().map(|line| line.stroke_color));
180
+
181
+ outputs.push(Mit48pxBlockPrediction {
182
+ block_index,
183
+ text,
184
+ confidence,
185
+ text_color,
186
+ stroke_color,
187
+ });
188
+ }
189
+
190
+ Ok(outputs)
191
+ }
192
+
193
+ fn decode_prediction(&self, prediction: RawPrediction) -> Mit48pxPrediction {
194
+ let mut text = String::new();
195
+ let mut fg_sum = [0f32; 3];
196
+ let mut bg_sum = [0f32; 3];
197
+ let mut fg_count = 0usize;
198
+ let mut bg_count = 0usize;
199
+ let mut has_text_color = false;
200
+ let mut has_stroke_color = false;
201
+
202
+ let len = prediction
203
+ .token_ids
204
+ .len()
205
+ .min(prediction.fg_colors.len())
206
+ .min(prediction.bg_colors.len())
207
+ .min(prediction.fg_indicators.len())
208
+ .min(prediction.bg_indicators.len());
209
+
210
+ for index in 0..len {
211
+ let token_id = prediction.token_ids[index] as usize;
212
+ let token = self
213
+ .dictionary
214
+ .get(token_id)
215
+ .map(String::as_str)
216
+ .unwrap_or("<UNK>");
217
+ if token == "<S>" {
218
+ continue;
219
+ }
220
+ if token == "</S>" {
221
+ break;
222
+ }
223
+
224
+ if token == self.config.space_token {
225
+ text.push(' ');
226
+ } else {
227
+ text.push_str(token);
228
+ }
229
+
230
+ let fg = prediction.fg_colors[index];
231
+ let bg = prediction.bg_colors[index];
232
+ let fg_present =
233
+ prediction.fg_indicators[index][1] > prediction.fg_indicators[index][0];
234
+ let bg_present =
235
+ prediction.bg_indicators[index][1] > prediction.bg_indicators[index][0];
236
+ if fg_present {
237
+ has_text_color = true;
238
+ accumulate_rgb(&mut fg_sum, fg);
239
+ fg_count += 1;
240
+ }
241
+ if bg_present {
242
+ has_stroke_color = true;
243
+ accumulate_rgb(&mut bg_sum, bg);
244
+ bg_count += 1;
245
+ } else {
246
+ accumulate_rgb(&mut bg_sum, fg);
247
+ bg_count += 1;
248
+ }
249
+ }
250
+
251
+ Mit48pxPrediction {
252
+ text,
253
+ confidence: prediction.confidence,
254
+ text_color: finish_rgb(fg_sum, fg_count),
255
+ stroke_color: finish_rgb(bg_sum, bg_count),
256
+ has_text_color,
257
+ has_stroke_color,
258
+ }
259
+ }
260
+ }
261
+
262
+ fn read_dictionary(path: &Path) -> Result<Vec<String>> {
263
+ let data = std::fs::read_to_string(path)
264
+ .with_context(|| format!("failed to read {}", path.display()))?;
265
+ Ok(data
266
+ .lines()
267
+ .map(|line| line.trim_end_matches('\r').to_string())
268
+ .collect())
269
+ }
270
+
271
+ fn preprocess_regions(
272
+ regions: &[DynamicImage],
273
+ config: &Mit48pxConfig,
274
+ device: &Device,
275
+ ) -> Result<PreparedBatch> {
276
+ let mut resized = Vec::<RgbImage>::with_capacity(regions.len());
277
+ let mut widths = Vec::with_capacity(regions.len());
278
+ let mut max_width = 1u32;
279
+
280
+ for region in regions {
281
+ let region = resize_region(region, config.text_height, config.max_width);
282
+ max_width = max_width.max(region.width());
283
+ widths.push(region.width());
284
+ resized.push(region);
285
+ }
286
+
287
+ // The source checkpoint expects seven blank pixels before the ConvNeXt
288
+ // backbone. That extra slack affects the backbone feature width and therefore the
289
+ // encoder mask shape, so keep it byte-for-byte compatible instead of rounding to 4.
290
+ let padded_width = max_width.saturating_add(7);
291
+ let height = config.text_height as usize;
292
+ let width = padded_width as usize;
293
+ let mut flat = vec![-1.0f32; resized.len() * height * width * 3];
294
+
295
+ for (batch_index, image) in resized.iter().enumerate() {
296
+ for y in 0..image.height() as usize {
297
+ for x in 0..image.width() as usize {
298
+ let pixel = image.get_pixel(x as u32, y as u32).0;
299
+ let offset = ((batch_index * height + y) * width + x) * 3;
300
+ flat[offset] = pixel[0] as f32 / 127.5 - 1.0;
301
+ flat[offset + 1] = pixel[1] as f32 / 127.5 - 1.0;
302
+ flat[offset + 2] = pixel[2] as f32 / 127.5 - 1.0;
303
+ }
304
+ }
305
+ }
306
+
307
+ let tensor =
308
+ Tensor::from_vec(flat, (resized.len(), height, width, 3), device)?.permute((0, 3, 1, 2))?;
309
+ Ok(PreparedBatch { tensor, widths })
310
+ }
311
+
312
+ fn resize_region(region: &DynamicImage, text_height: u32, max_width: u32) -> RgbImage {
313
+ let rgb = region.to_rgb8();
314
+ let (width, height) = rgb.dimensions();
315
+ let new_width = ((width as f32 / height.max(1) as f32) * text_height as f32)
316
+ .round()
317
+ .clamp(1.0, max_width as f32) as u32;
318
+ if width == new_width && height == text_height {
319
+ rgb
320
+ } else {
321
+ image::imageops::resize(&rgb, new_width, text_height, FilterType::Triangle)
322
+ }
323
+ }
324
+
325
/// Adds `color`, scaled by 255, channel by channel onto the running `sum`.
fn accumulate_rgb(sum: &mut [f32; 3], color: [f32; 3]) {
    for channel in 0..3 {
        sum[channel] += color[channel] * 255.0;
    }
}
330
+
331
/// Turns accumulated channel sums into an averaged 8-bit color.
///
/// Each channel is divided by `count`, rounded, and clamped to `0..=255`.
/// Returns black when `count` is zero.
fn finish_rgb(sum: [f32; 3], count: usize) -> [u8; 3] {
    if count == 0 {
        return [0, 0, 0];
    }
    let denom = count as f32;
    sum.map(|channel| ((channel / denom).round() as i32).clamp(0, 255) as u8)
}
342
+
343
/// Averages a sequence of 8-bit RGB colors channel by channel.
///
/// Each mean is rounded and clamped to `0..=255`. An empty sequence yields black.
fn average_rgb(colors: impl Iterator<Item = [u8; 3]>) -> [u8; 3] {
    let (sum, count) = colors.fold(([0f32; 3], 0usize), |(mut acc, seen), color| {
        for (slot, channel) in acc.iter_mut().zip(color) {
            *slot += channel as f32;
        }
        (acc, seen + 1)
    });

    if count == 0 {
        return [0, 0, 0];
    }
    let denom = count as f32;
    sum.map(|total| (total / denom).round().clamp(0.0, 255.0) as u8)
}
361
+
362
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use image::{DynamicImage, RgbImage};

    use super::{Mit48pxConfig, Mit48pxOcr, Mit48pxPrediction, finish_rgb, preprocess_regions};

    /// Config fixture used by the preprocessing test.
    fn test_config() -> Mit48pxConfig {
        Mit48pxConfig {
            text_height: 48,
            max_width: 8100,
            embd_dim: 320,
            num_heads: 4,
            encoder_layers: 4,
            decoder_layers: 5,
            beam_size_default: 5,
            max_seq_length_default: 255,
            pad_token_id: 0,
            bos_token_id: 1,
            eos_token_id: 2,
            space_token: "<SP>".to_string(),
            dictionary_file: "alphabet-all-v7.txt".to_string(),
        }
    }

    /// Parent directory of this crate's manifest directory.
    fn repo_root() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("..")
    }

    /// Loads the CPU model from the local asset directory, failing when the
    /// directory is absent.
    fn load_local_model() -> anyhow::Result<Mit48pxOcr> {
        let model_dir = repo_root().join("target/mit48px-local");
        if !model_dir.exists() {
            anyhow::bail!("missing local mit48px assets at {}", model_dir.display());
        }
        Mit48pxOcr::load_from_dir(&model_dir, true)
    }

    #[test]
    fn preprocessing_resizes_to_48px_and_matches_ballonstranslator_width_padding()
    -> anyhow::Result<()> {
        // 25x10 scaled to height 48 gives width 120; padding adds 7.
        let solid_red = RgbImage::from_pixel(25, 10, image::Rgb([255, 0, 0]));
        let regions = [DynamicImage::ImageRgb8(solid_red)];
        let batch = preprocess_regions(&regions, &test_config(), &candle_core::Device::Cpu)?;
        assert_eq!(batch.widths, vec![120]);
        assert_eq!(batch.tensor.dims(), &[1, 3, 48, 127]);
        Ok(())
    }

    #[test]
    fn finish_rgb_clamps_to_u8_range() {
        assert_eq!(finish_rgb([300.0, 40.0, -10.0], 1), [255, 40, 0]);
    }

    #[test]
    fn block_prediction_shape_remains_serializable() -> anyhow::Result<()> {
        let prediction = Mit48pxPrediction {
            text: "abc".to_string(),
            confidence: 0.5,
            text_color: [1, 2, 3],
            stroke_color: [4, 5, 6],
            has_text_color: true,
            has_stroke_color: false,
        };
        let json = serde_json::to_string(&prediction)?;
        assert!(json.contains("\"hasTextColor\":true"));
        Ok(())
    }

    #[test]
    #[ignore]
    fn local_model_dir_loads_and_ocrs_a_crop() -> anyhow::Result<()> {
        let model = load_local_model()?;
        let page = image::open(repo_root().join("data/bluearchive_comics/1.jpg"))?;
        let crop = page.crop_imm(66, 26, 270, 48);
        let output = model.inference_regions(&[crop])?;
        assert_eq!(output.len(), 1);
        assert!(!output[0].text.is_empty());
        Ok(())
    }

    #[test]
    #[ignore]
    fn local_model_matches_reference_text_on_known_crop() -> anyhow::Result<()> {
        let model = load_local_model()?;
        let page = image::open(repo_root().join("data/140817417_p0.jpg"))?;
        let crop = page.crop_imm(48, 232, 1172, 388);
        let output = model.inference_regions(&[crop])?;
        assert_eq!(output.len(), 1);
        assert_eq!(output[0].text, "デカグラマトン戦闘");
        Ok(())
    }
}
koharu-ml/src/mit48px_ocr/model.rs ADDED
@@ -0,0 +1,1014 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use anyhow::Result;
2
+ use candle_core::{D, DType, Device, Tensor};
3
+ use candle_nn::{
4
+ BatchNorm, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear, Module,
5
+ ModuleT, VarBuilder, conv2d, embedding, layer_norm,
6
+ };
7
+
8
+ use super::Mit48pxConfig;
9
+
10
+ const LAYER_NORM_EPS: f64 = 1e-5;
11
+ const MAX_FINISHED_HYPOS: usize = 2;
12
+ type TopkOutput = (Vec<Vec<f32>>, Vec<Vec<u32>>);
13
+
14
/// Undecoded model output for one input image.
///
/// The per-token vectors are walked in lockstep by the decoder, which stops at
/// the shortest of them.
#[derive(Debug, Clone)]
pub(crate) struct RawPrediction {
    /// Token ids, used as indices into the dictionary.
    pub token_ids: Vec<u32>,
    /// Confidence of the selected hypothesis.
    pub confidence: f32,
    /// Per-token foreground color; the decoder scales each channel by 255.
    pub fg_colors: Vec<[f32; 3]>,
    /// Per-token stroke color; the decoder scales each channel by 255.
    pub bg_colors: Vec<[f32; 3]>,
    /// Per-token score pair; a foreground color counts as present when entry 1
    /// exceeds entry 0.
    pub fg_indicators: Vec<[f32; 2]>,
    /// Per-token score pair; a stroke color counts as present when entry 1
    /// exceeds entry 0.
    pub bg_indicators: Vec<[f32; 2]>,
}
23
+
24
+ pub(crate) struct Mit48pxModel {
25
+ config: Mit48pxConfig,
26
+ backbone: ConvNextFeatureExtractor,
27
+ encoders: Vec<TransformerEncoderLayer>,
28
+ decoders: Vec<TransformerDecoderLayer>,
29
+ embedding: Embedding,
30
+ pred1: Linear,
31
+ pred: Linear,
32
+ color_pred1: Linear,
33
+ color_pred_fg: Linear,
34
+ color_pred_bg: Linear,
35
+ color_pred_fg_ind: Linear,
36
+ color_pred_bg_ind: Linear,
37
+ device: Device,
38
+ }
39
+
40
+ #[derive(Clone)]
41
+ struct Hypothesis {
42
+ sample_index: usize,
43
+ token_ids: Vec<u32>,
44
+ sum_logprob: f32,
45
+ cached_activations: Vec<Tensor>,
46
+ }
47
+
48
+ fn topk_last_dim(tensor: &Tensor, topk: usize) -> Result<TopkOutput> {
49
+ let rows = tensor.to_vec2::<f32>()?;
50
+ let mut values = Vec::with_capacity(rows.len());
51
+ let mut indices = Vec::with_capacity(rows.len());
52
+
53
+ for row in rows {
54
+ let mut ranked = row.into_iter().enumerate().collect::<Vec<_>>();
55
+ ranked.sort_by(|(left_idx, left), (right_idx, right)| {
56
+ right.total_cmp(left).then_with(|| left_idx.cmp(right_idx))
57
+ });
58
+ ranked.truncate(topk);
59
+ values.push(ranked.iter().map(|(_, value)| *value).collect());
60
+ indices.push(ranked.into_iter().map(|(index, _)| index as u32).collect());
61
+ }
62
+
63
+ Ok((values, indices))
64
+ }
65
+
66
+ fn cat_batch(tensors: &[Tensor]) -> Result<Tensor> {
67
+ let refs = tensors.iter().collect::<Vec<_>>();
68
+ Ok(Tensor::cat(&refs, 0)?)
69
+ }
70
+
71
+ fn load_linear(vb: VarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
72
+ Ok(Linear::new(
73
+ vb.get((out_dim, in_dim), "weight")?,
74
+ Some(vb.get(out_dim, "bias")?),
75
+ ))
76
+ }
77
+
78
+ fn load_batch_norm(vb: VarBuilder, channels: usize) -> Result<BatchNorm> {
79
+ Ok(BatchNorm::new(
80
+ channels,
81
+ vb.get(channels, "running_mean")?,
82
+ vb.get(channels, "running_var")?,
83
+ vb.get(channels, "weight")?,
84
+ vb.get(channels, "bias")?,
85
+ 1e-5,
86
+ )?)
87
+ }
88
+
89
+ impl Hypothesis {
90
+ fn new(
91
+ sample_index: usize,
92
+ bos_token_id: u32,
93
+ decoder_layers: usize,
94
+ embd_dim: usize,
95
+ device: &Device,
96
+ ) -> Result<Self> {
97
+ let mut cached_activations = Vec::with_capacity(decoder_layers + 1);
98
+ for _ in 0..=decoder_layers {
99
+ cached_activations.push(Tensor::zeros((1, 0, embd_dim), DType::F32, device)?);
100
+ }
101
+ Ok(Self {
102
+ sample_index,
103
+ token_ids: vec![bos_token_id],
104
+ sum_logprob: 0.0,
105
+ cached_activations,
106
+ })
107
+ }
108
+
109
+ fn decoded_len(&self) -> usize {
110
+ self.token_ids.len().saturating_sub(1)
111
+ }
112
+
113
+ fn avg_logprob(&self) -> f32 {
114
+ let len = self.decoded_len().max(1) as f32;
115
+ self.sum_logprob / len
116
+ }
117
+
118
+ fn probability(&self) -> f32 {
119
+ self.avg_logprob().exp()
120
+ }
121
+
122
+ fn last_token(&self) -> u32 {
123
+ *self.token_ids.last().expect("hypothesis has bos token")
124
+ }
125
+
126
+ fn seq_end(&self, eos_token_id: u32) -> bool {
127
+ self.last_token() == eos_token_id
128
+ }
129
+
130
+ fn extend(&self, token_id: u32, logprob: f32) -> Self {
131
+ let mut token_ids = self.token_ids.clone();
132
+ token_ids.push(token_id);
133
+ Self {
134
+ sample_index: self.sample_index,
135
+ token_ids,
136
+ sum_logprob: self.sum_logprob + logprob,
137
+ cached_activations: self.cached_activations.to_vec(),
138
+ }
139
+ }
140
+
141
+ fn output(&self) -> &Tensor {
142
+ self.cached_activations
143
+ .last()
144
+ .expect("decoder output cache exists")
145
+ }
146
+
147
+ fn score_cmp(a: &Self, b: &Self) -> std::cmp::Ordering {
148
+ a.avg_logprob().total_cmp(&b.avg_logprob())
149
+ }
150
+
151
+ fn descending(a: &Self, b: &Self) -> std::cmp::Ordering {
152
+ b.avg_logprob().total_cmp(&a.avg_logprob())
153
+ }
154
+ }
155
+
156
+ impl Mit48pxModel {
157
+ pub(crate) fn new(
158
+ config: Mit48pxConfig,
159
+ vocab_size: usize,
160
+ vb: VarBuilder,
161
+ device: Device,
162
+ ) -> Result<Self> {
163
+ let backbone = ConvNextFeatureExtractor::new(vb.pp("backbone"))?;
164
+ let encoders = (0..config.encoder_layers)
165
+ .map(|index| TransformerEncoderLayer::new(vb.pp(format!("encoders.{index}"))))
166
+ .collect::<Result<Vec<_>>>()?;
167
+ let decoders = (0..config.decoder_layers)
168
+ .map(|index| TransformerDecoderLayer::new(vb.pp(format!("decoders.{index}"))))
169
+ .collect::<Result<Vec<_>>>()?;
170
+ let embedding = embedding(vocab_size, config.embd_dim, vb.pp("embd"))?;
171
+ let pred1 = load_linear(vb.pp("pred1.0"), config.embd_dim, config.embd_dim)?;
172
+ let pred = load_linear(vb.pp("pred"), config.embd_dim, vocab_size)?;
173
+ let color_pred1 = load_linear(vb.pp("color_pred1.0"), config.embd_dim, 64)?;
174
+ let color_pred_fg = load_linear(vb.pp("color_pred_fg"), 64, 3)?;
175
+ let color_pred_bg = load_linear(vb.pp("color_pred_bg"), 64, 3)?;
176
+ let color_pred_fg_ind = load_linear(vb.pp("color_pred_fg_ind"), 64, 2)?;
177
+ let color_pred_bg_ind = load_linear(vb.pp("color_pred_bg_ind"), 64, 2)?;
178
+
179
+ Ok(Self {
180
+ config,
181
+ backbone,
182
+ encoders,
183
+ decoders,
184
+ embedding,
185
+ pred1,
186
+ pred,
187
+ color_pred1,
188
+ color_pred_fg,
189
+ color_pred_bg,
190
+ color_pred_fg_ind,
191
+ color_pred_bg_ind,
192
+ device,
193
+ })
194
+ }
195
+
196
+ pub(crate) fn infer_batch(
197
+ &self,
198
+ images: &Tensor,
199
+ image_widths: &[u32],
200
+ ) -> Result<Vec<RawPrediction>> {
201
+ let (memory, memory_mask) = self.encode(images, image_widths)?;
202
+ let batch_size = images.dim(0)?;
203
+ let beam_size = self.config.beam_size_default.max(1);
204
+ let max_seq_length = self.config.max_seq_length_default.max(1);
205
+ let bos = self.config.bos_token_id;
206
+ let eos = self.config.eos_token_id;
207
+
208
+ let mut finished = vec![Vec::<Hypothesis>::new(); batch_size];
209
+ let mut best_fallback = vec![None::<Hypothesis>; batch_size];
210
+
211
+ let mut seed_hyps = (0..batch_size)
212
+ .map(|sample_index| {
213
+ Hypothesis::new(
214
+ sample_index,
215
+ bos,
216
+ self.decoders.len(),
217
+ self.config.embd_dim,
218
+ &self.device,
219
+ )
220
+ })
221
+ .collect::<Result<Vec<_>>>()?;
222
+
223
+ let decoded = self.next_token_batch(&mut seed_hyps, &memory, &memory_mask)?;
224
+ let (values, indices) = self.next_token_candidates(&decoded, beam_size)?;
225
+ let mut active = Vec::with_capacity(batch_size * beam_size);
226
+ for sample_index in 0..batch_size {
227
+ let mut candidates = Vec::with_capacity(beam_size);
228
+ for beam_index in 0..beam_size {
229
+ candidates.push(seed_hyps[sample_index].extend(
230
+ indices[sample_index][beam_index],
231
+ values[sample_index][beam_index],
232
+ ));
233
+ }
234
+ candidates.sort_by(Hypothesis::descending);
235
+ best_fallback[sample_index] = candidates.first().cloned();
236
+ let mut kept_active = 0usize;
237
+ for candidate in candidates {
238
+ if candidate.seq_end(eos) {
239
+ finished[sample_index].push(candidate);
240
+ if finished[sample_index].len() >= MAX_FINISHED_HYPOS {
241
+ break;
242
+ }
243
+ } else if kept_active < beam_size {
244
+ kept_active += 1;
245
+ active.push(candidate);
246
+ }
247
+ }
248
+ }
249
+
250
+ for _step in 1..max_seq_length {
251
+ if active.is_empty() {
252
+ break;
253
+ }
254
+
255
+ let decoded = self.next_token_batch(&mut active, &memory, &memory_mask)?;
256
+ let (values, indices) = self.next_token_candidates(&decoded, beam_size)?;
257
+
258
+ let mut per_sample = vec![Vec::<Hypothesis>::new(); batch_size];
259
+ for (hyp_index, hypothesis) in active.iter().enumerate() {
260
+ for beam_index in 0..beam_size {
261
+ per_sample[hypothesis.sample_index].push(hypothesis.extend(
262
+ indices[hyp_index][beam_index],
263
+ values[hyp_index][beam_index],
264
+ ));
265
+ }
266
+ }
267
+
268
+ active.clear();
269
+ for sample_index in 0..batch_size {
270
+ if per_sample[sample_index].is_empty() {
271
+ continue;
272
+ }
273
+ per_sample[sample_index].sort_by(Hypothesis::descending);
274
+ best_fallback[sample_index] = per_sample[sample_index].first().cloned();
275
+
276
+ if finished[sample_index].len() >= MAX_FINISHED_HYPOS {
277
+ continue;
278
+ }
279
+
280
+ let mut kept_active = 0usize;
281
+ for candidate in per_sample[sample_index].drain(..) {
282
+ if candidate.seq_end(eos) {
283
+ finished[sample_index].push(candidate);
284
+ if finished[sample_index].len() >= MAX_FINISHED_HYPOS {
285
+ break;
286
+ }
287
+ } else if kept_active < beam_size {
288
+ kept_active += 1;
289
+ active.push(candidate);
290
+ }
291
+ }
292
+ }
293
+ }
294
+
295
+ let mut outputs = Vec::with_capacity(batch_size);
296
+ for sample_index in 0..batch_size {
297
+ let best = if finished[sample_index].is_empty() {
298
+ best_fallback[sample_index]
299
+ .clone()
300
+ .or_else(|| {
301
+ active
302
+ .iter()
303
+ .filter(|hyp| hyp.sample_index == sample_index)
304
+ .cloned()
305
+ .max_by(Hypothesis::score_cmp)
306
+ })
307
+ .ok_or_else(|| {
308
+ anyhow::anyhow!("no beam hypothesis for sample {sample_index}")
309
+ })?
310
+ } else {
311
+ finished[sample_index]
312
+ .iter()
313
+ .cloned()
314
+ .max_by(Hypothesis::score_cmp)
315
+ .expect("non-empty finished")
316
+ };
317
+ outputs.push(self.build_raw_prediction(&best)?);
318
+ }
319
+
320
+ Ok(outputs)
321
+ }
322
+
323
+ fn encode(&self, images: &Tensor, image_widths: &[u32]) -> Result<(Tensor, Tensor)> {
324
+ let mut memory = self.backbone.forward(images)?;
325
+ let (_, _, height, width) = memory.dims4()?;
326
+ anyhow::ensure!(height == 1, "unexpected backbone height: {height}");
327
+ memory = memory.squeeze(2)?.transpose(1, 2)?;
328
+
329
+ let mut mask_values = vec![0u8; image_widths.len() * width];
330
+ for (batch_index, width_px) in image_widths.iter().enumerate() {
331
+ let valid_len = ((*width_px as usize).div_ceil(4) + 2).min(width);
332
+ for pos in valid_len..width {
333
+ mask_values[batch_index * width + pos] = 1;
334
+ }
335
+ }
336
+ let memory_mask = Tensor::from_vec(mask_values, (image_widths.len(), width), &self.device)?;
337
+ for layer in &self.encoders {
338
+ memory = layer.forward(&memory, Some(&memory_mask))?;
339
+ }
340
+ Ok((memory, memory_mask))
341
+ }
342
+
343
+ fn next_token_batch(
344
+ &self,
345
+ hyps: &mut [Hypothesis],
346
+ memory: &Tensor,
347
+ memory_mask: &Tensor,
348
+ ) -> Result<Tensor> {
349
+ let offset = hyps.first().map(Hypothesis::decoded_len).unwrap_or(0);
350
+ let batch = hyps.len();
351
+ let sample_indices = hyps
352
+ .iter()
353
+ .map(|hyp| hyp.sample_index as u32)
354
+ .collect::<Vec<_>>();
355
+ let sample_indices = Tensor::from_vec(sample_indices, (batch,), &self.device)?;
356
+ let selected_memory = memory.index_select(&sample_indices, 0)?;
357
+ let selected_mask = memory_mask.index_select(&sample_indices, 0)?;
358
+
359
+ let last_tokens = hyps.iter().map(Hypothesis::last_token).collect::<Vec<_>>();
360
+ let last_tokens = Tensor::from_vec(last_tokens, (batch,), &self.device)?;
361
+ let mut tgt =
362
+ self.embedding
363
+ .forward(&last_tokens)?
364
+ .reshape((batch, 1, self.config.embd_dim))?;
365
+
366
+ for (layer_index, layer) in self.decoders.iter().enumerate() {
367
+ let previous = if offset == 0 {
368
+ None
369
+ } else {
370
+ let refs = hyps
371
+ .iter()
372
+ .map(|hyp| hyp.cached_activations[layer_index].clone())
373
+ .collect::<Vec<_>>();
374
+ Some(cat_batch(&refs)?)
375
+ };
376
+ let combined = if let Some(previous) = previous {
377
+ Tensor::cat(&[&previous, &tgt], 1)?
378
+ } else {
379
+ tgt.clone()
380
+ };
381
+ for (hyp_index, hyp) in hyps.iter_mut().enumerate() {
382
+ hyp.cached_activations[layer_index] = combined.narrow(0, hyp_index, 1)?;
383
+ }
384
+ tgt =
385
+ layer.forward_cached(&tgt, &combined, &selected_memory, &selected_mask, offset)?;
386
+ }
387
+
388
+ for (hyp_index, hyp) in hyps.iter_mut().enumerate() {
389
+ let current = tgt.narrow(0, hyp_index, 1)?;
390
+ hyp.cached_activations[self.decoders.len()] = if offset == 0 {
391
+ current
392
+ } else {
393
+ Tensor::cat(&[&hyp.cached_activations[self.decoders.len()], &current], 1)?
394
+ };
395
+ }
396
+
397
+ Ok(tgt.squeeze(1)?)
398
+ }
399
+
400
+ fn next_token_candidates(&self, decoded: &Tensor, beam_size: usize) -> Result<TopkOutput> {
401
+ let pred_feats = self.pred1.forward(decoded)?.gelu_erf()?;
402
+ let logits = self.pred.forward(&pred_feats)?;
403
+ let log_probs = candle_nn::ops::log_softmax(&logits, D::Minus1)?;
404
+ topk_last_dim(&log_probs, beam_size)
405
+ }
406
+
407
+ fn build_raw_prediction(&self, hypothesis: &Hypothesis) -> Result<RawPrediction> {
408
+ let decoded = hypothesis.output();
409
+ let color_feats = self.color_pred1.forward(decoded)?.relu()?;
410
+ let fg_colors = self
411
+ .color_pred_fg
412
+ .forward(&color_feats)?
413
+ .squeeze(0)?
414
+ .to_vec2::<f32>()?
415
+ .into_iter()
416
+ .map(|row| [row[0], row[1], row[2]])
417
+ .collect();
418
+ let bg_colors = self
419
+ .color_pred_bg
420
+ .forward(&color_feats)?
421
+ .squeeze(0)?
422
+ .to_vec2::<f32>()?
423
+ .into_iter()
424
+ .map(|row| [row[0], row[1], row[2]])
425
+ .collect();
426
+ let fg_indicators = self
427
+ .color_pred_fg_ind
428
+ .forward(&color_feats)?
429
+ .squeeze(0)?
430
+ .to_vec2::<f32>()?
431
+ .into_iter()
432
+ .map(|row| [row[0], row[1]])
433
+ .collect();
434
+ let bg_indicators = self
435
+ .color_pred_bg_ind
436
+ .forward(&color_feats)?
437
+ .squeeze(0)?
438
+ .to_vec2::<f32>()?
439
+ .into_iter()
440
+ .map(|row| [row[0], row[1]])
441
+ .collect();
442
+
443
+ Ok(RawPrediction {
444
+ token_ids: hypothesis.token_ids[1..].to_vec(),
445
+ confidence: hypothesis.probability(),
446
+ fg_colors,
447
+ bg_colors,
448
+ fg_indicators,
449
+ bg_indicators,
450
+ })
451
+ }
452
+ }
453
+
454
+ struct ConvBnRelu2d {
455
+ conv: Conv2d,
456
+ bn: BatchNorm,
457
+ }
458
+
459
+ impl ConvBnRelu2d {
460
+ fn new(
461
+ vb: VarBuilder,
462
+ in_channels: usize,
463
+ out_channels: usize,
464
+ kernel: usize,
465
+ stride: usize,
466
+ padding: usize,
467
+ ) -> Result<Self> {
468
+ let conv = conv2d(
469
+ in_channels,
470
+ out_channels,
471
+ kernel,
472
+ Conv2dConfig {
473
+ stride,
474
+ padding,
475
+ ..Default::default()
476
+ },
477
+ vb.pp("0"),
478
+ )?;
479
+ let bn = load_batch_norm(vb.pp("1"), out_channels)?;
480
+ Ok(Self { conv, bn })
481
+ }
482
+
483
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
484
+ let xs = self.conv.forward(xs)?;
485
+ let xs = self.bn.forward_t(&xs, false)?;
486
+ Ok(xs.relu()?)
487
+ }
488
+ }
489
+
490
+ struct HeightConv {
491
+ conv: Conv1d,
492
+ out_channels: usize,
493
+ }
494
+
495
+ impl HeightConv {
496
+ fn new(
497
+ vb: VarBuilder,
498
+ in_channels: usize,
499
+ out_channels: usize,
500
+ kernel: usize,
501
+ stride: usize,
502
+ ) -> Result<Self> {
503
+ let weight = vb
504
+ .get((out_channels, in_channels, kernel, 1), "weight")?
505
+ .reshape((out_channels, in_channels, kernel))?;
506
+ let bias = vb.get(out_channels, "bias")?;
507
+ let conv = Conv1d::new(
508
+ weight,
509
+ Some(bias),
510
+ Conv1dConfig {
511
+ stride,
512
+ ..Default::default()
513
+ },
514
+ );
515
+ Ok(Self { conv, out_channels })
516
+ }
517
+
518
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
519
+ let (batch, channels, height, width) = xs.dims4()?;
520
+ let reshaped = xs
521
+ .permute((0, 3, 1, 2))?
522
+ .reshape((batch * width, channels, height))?;
523
+ let ys = self.conv.forward(&reshaped)?;
524
+ let out_height = ys.dim(2)?;
525
+ Ok(ys
526
+ .reshape((batch, width, self.out_channels, out_height))?
527
+ .permute((0, 2, 3, 1))?)
528
+ }
529
+ }
530
+
531
+ struct HeightConvBnRelu {
532
+ conv: HeightConv,
533
+ bn: BatchNorm,
534
+ }
535
+
536
+ impl HeightConvBnRelu {
537
+ fn new(
538
+ vb: VarBuilder,
539
+ in_channels: usize,
540
+ out_channels: usize,
541
+ kernel: usize,
542
+ stride: usize,
543
+ ) -> Result<Self> {
544
+ let conv = HeightConv::new(vb.pp("0"), in_channels, out_channels, kernel, stride)?;
545
+ let bn = load_batch_norm(vb.pp("1"), out_channels)?;
546
+ Ok(Self { conv, bn })
547
+ }
548
+
549
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
550
+ let xs = self.conv.forward(xs)?;
551
+ let xs = self.bn.forward_t(&xs, false)?;
552
+ Ok(xs.relu()?)
553
+ }
554
+ }
555
+
556
+ struct ConvNeXtBlock {
557
+ dwconv: Conv2d,
558
+ norm: BatchNorm,
559
+ pwconv1: Conv2d,
560
+ pwconv2: Conv2d,
561
+ gamma: Tensor,
562
+ }
563
+
564
+ impl ConvNeXtBlock {
565
+ fn new(vb: VarBuilder, dim: usize, kernel: usize, padding: usize) -> Result<Self> {
566
+ let dwconv = conv2d(
567
+ dim,
568
+ dim,
569
+ kernel,
570
+ Conv2dConfig {
571
+ padding,
572
+ groups: dim,
573
+ ..Default::default()
574
+ },
575
+ vb.pp("dwconv"),
576
+ )?;
577
+ let norm = load_batch_norm(vb.pp("norm"), dim)?;
578
+ let pwconv1 = conv2d(dim, dim * 4, 1, Conv2dConfig::default(), vb.pp("pwconv1"))?;
579
+ let pwconv2 = conv2d(dim * 4, dim, 1, Conv2dConfig::default(), vb.pp("pwconv2"))?;
580
+ let gamma = vb.get((1, dim, 1, 1), "gamma")?;
581
+ Ok(Self {
582
+ dwconv,
583
+ norm,
584
+ pwconv1,
585
+ pwconv2,
586
+ gamma,
587
+ })
588
+ }
589
+
590
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
591
+ let residual = xs;
592
+ let xs = self.dwconv.forward(xs)?;
593
+ let xs = self.norm.forward_t(&xs, false)?;
594
+ let xs = self.pwconv1.forward(&xs)?.gelu_erf()?;
595
+ let xs = self.pwconv2.forward(&xs)?;
596
+ Ok(residual.broadcast_add(&xs.broadcast_mul(&self.gamma)?)?)
597
+ }
598
+ }
599
+
600
+ struct ConvNextFeatureExtractor {
601
+ stem0: Conv2d,
602
+ stem1: BatchNorm,
603
+ stem2: Conv2d,
604
+ stem3: BatchNorm,
605
+ stem4: Conv2d,
606
+ stem5: BatchNorm,
607
+ block1: Vec<ConvNeXtBlock>,
608
+ down1: ConvBnRelu2d,
609
+ block2: Vec<ConvNeXtBlock>,
610
+ down2: HeightConvBnRelu,
611
+ block3: Vec<ConvNeXtBlock>,
612
+ down3: HeightConvBnRelu,
613
+ block4: Vec<ConvNeXtBlock>,
614
+ down4: HeightConvBnRelu,
615
+ }
616
+
617
+ impl ConvNextFeatureExtractor {
618
+ fn new(vb: VarBuilder) -> Result<Self> {
619
+ let stem0 = conv2d(
620
+ 3,
621
+ 40,
622
+ 7,
623
+ Conv2dConfig {
624
+ padding: 3,
625
+ ..Default::default()
626
+ },
627
+ vb.pp("stem.0"),
628
+ )?;
629
+ let stem1 = load_batch_norm(vb.pp("stem.1"), 40)?;
630
+ let stem2 = conv2d(
631
+ 40,
632
+ 80,
633
+ 2,
634
+ Conv2dConfig {
635
+ stride: 2,
636
+ ..Default::default()
637
+ },
638
+ vb.pp("stem.3"),
639
+ )?;
640
+ let stem3 = load_batch_norm(vb.pp("stem.4"), 80)?;
641
+ let stem4 = conv2d(
642
+ 80,
643
+ 80,
644
+ 3,
645
+ Conv2dConfig {
646
+ padding: 1,
647
+ ..Default::default()
648
+ },
649
+ vb.pp("stem.6"),
650
+ )?;
651
+ let stem5 = load_batch_norm(vb.pp("stem.7"), 80)?;
652
+
653
+ Ok(Self {
654
+ stem0,
655
+ stem1,
656
+ stem2,
657
+ stem3,
658
+ stem4,
659
+ stem5,
660
+ block1: make_convnext_layers(vb.pp("block1"), 80, 4, 7, 3)?,
661
+ down1: ConvBnRelu2d::new(vb.pp("down1"), 80, 160, 2, 2, 0)?,
662
+ block2: make_convnext_layers(vb.pp("block2"), 160, 12, 7, 3)?,
663
+ down2: HeightConvBnRelu::new(vb.pp("down2"), 160, 320, 2, 2)?,
664
+ block3: make_convnext_layers(vb.pp("block3"), 320, 10, 5, 2)?,
665
+ down3: HeightConvBnRelu::new(vb.pp("down3"), 320, 320, 2, 2)?,
666
+ block4: make_convnext_layers(vb.pp("block4"), 320, 8, 3, 1)?,
667
+ down4: HeightConvBnRelu::new(vb.pp("down4"), 320, 320, 3, 1)?,
668
+ })
669
+ }
670
+
671
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
672
+ let mut xs = self.stem0.forward(xs)?;
673
+ xs = self.stem1.forward_t(&xs, false)?.relu()?;
674
+ xs = self.stem2.forward(&xs)?;
675
+ xs = self.stem3.forward_t(&xs, false)?.relu()?;
676
+ xs = self.stem4.forward(&xs)?;
677
+ xs = self.stem5.forward_t(&xs, false)?.relu()?;
678
+
679
+ for block in &self.block1 {
680
+ xs = block.forward(&xs)?;
681
+ }
682
+ xs = self.down1.forward(&xs)?;
683
+ for block in &self.block2 {
684
+ xs = block.forward(&xs)?;
685
+ }
686
+ xs = self.down2.forward(&xs)?;
687
+ for block in &self.block3 {
688
+ xs = block.forward(&xs)?;
689
+ }
690
+ xs = self.down3.forward(&xs)?;
691
+ for block in &self.block4 {
692
+ xs = block.forward(&xs)?;
693
+ }
694
+ self.down4.forward(&xs)
695
+ }
696
+ }
697
+
698
+ fn make_convnext_layers(
699
+ vb: VarBuilder,
700
+ dim: usize,
701
+ count: usize,
702
+ kernel: usize,
703
+ padding: usize,
704
+ ) -> Result<Vec<ConvNeXtBlock>> {
705
+ (0..count)
706
+ .map(|index| ConvNeXtBlock::new(vb.pp(index.to_string()), dim, kernel, padding))
707
+ .collect()
708
+ }
709
+
710
+ struct TransformerEncoderLayer {
711
+ self_attn: XposMultiheadAttention,
712
+ linear1: Linear,
713
+ linear2: Linear,
714
+ norm1: LayerNorm,
715
+ norm2: LayerNorm,
716
+ }
717
+
718
+ impl TransformerEncoderLayer {
719
+ fn new(vb: VarBuilder) -> Result<Self> {
720
+ Ok(Self {
721
+ self_attn: XposMultiheadAttention::new(vb.pp("self_attn"), 320, 4)?,
722
+ linear1: load_linear(vb.pp("linear1"), 320, 2048)?,
723
+ linear2: load_linear(vb.pp("linear2"), 2048, 320)?,
724
+ norm1: layer_norm(320, LAYER_NORM_EPS, vb.pp("norm1"))?,
725
+ norm2: layer_norm(320, LAYER_NORM_EPS, vb.pp("norm2"))?,
726
+ })
727
+ }
728
+
729
+ fn forward(&self, src: &Tensor, src_key_padding_mask: Option<&Tensor>) -> Result<Tensor> {
730
+ let sa_input = self.norm1.forward(src)?;
731
+ let sa =
732
+ self.self_attn
733
+ .forward(&sa_input, &sa_input, &sa_input, src_key_padding_mask, 0, 0)?;
734
+ let src = src.broadcast_add(&sa)?;
735
+ let ff_input = self.norm2.forward(&src)?;
736
+ let ff = self
737
+ .linear2
738
+ .forward(&self.linear1.forward(&ff_input)?.relu()?)?;
739
+ Ok(src.broadcast_add(&ff)?)
740
+ }
741
+ }
742
+
743
+ struct TransformerDecoderLayer {
744
+ self_attn: XposMultiheadAttention,
745
+ multihead_attn: XposMultiheadAttention,
746
+ linear1: Linear,
747
+ linear2: Linear,
748
+ norm1: LayerNorm,
749
+ norm2: LayerNorm,
750
+ norm3: LayerNorm,
751
+ }
752
+
753
+ impl TransformerDecoderLayer {
754
+ fn new(vb: VarBuilder) -> Result<Self> {
755
+ Ok(Self {
756
+ self_attn: XposMultiheadAttention::new(vb.pp("self_attn"), 320, 4)?,
757
+ multihead_attn: XposMultiheadAttention::new(vb.pp("multihead_attn"), 320, 4)?,
758
+ linear1: load_linear(vb.pp("linear1"), 320, 2048)?,
759
+ linear2: load_linear(vb.pp("linear2"), 2048, 320)?,
760
+ norm1: layer_norm(320, LAYER_NORM_EPS, vb.pp("norm1"))?,
761
+ norm2: layer_norm(320, LAYER_NORM_EPS, vb.pp("norm2"))?,
762
+ norm3: layer_norm(320, LAYER_NORM_EPS, vb.pp("norm3"))?,
763
+ })
764
+ }
765
+
766
+ fn forward_cached(
767
+ &self,
768
+ tgt: &Tensor,
769
+ combined_activations: &Tensor,
770
+ memory: &Tensor,
771
+ memory_mask: &Tensor,
772
+ q_offset: usize,
773
+ ) -> Result<Tensor> {
774
+ let tgt_norm = self.norm1.forward(tgt)?;
775
+ let combined_norm = self.norm1.forward(combined_activations)?;
776
+ let self_attn =
777
+ self.self_attn
778
+ .forward(&tgt_norm, &combined_norm, &combined_norm, None, 0, q_offset)?;
779
+ let tgt = tgt.broadcast_add(&self_attn)?;
780
+
781
+ let cross_attn = self.multihead_attn.forward(
782
+ &self.norm2.forward(&tgt)?,
783
+ memory,
784
+ memory,
785
+ Some(memory_mask),
786
+ 0,
787
+ q_offset,
788
+ )?;
789
+ let tgt = tgt.broadcast_add(&cross_attn)?;
790
+
791
+ let ff = self
792
+ .linear2
793
+ .forward(&self.linear1.forward(&self.norm3.forward(&tgt)?)?.relu()?)?;
794
+ Ok(tgt.broadcast_add(&ff)?)
795
+ }
796
+ }
797
+
798
+ struct XposMultiheadAttention {
799
+ k_proj: Linear,
800
+ v_proj: Linear,
801
+ q_proj: Linear,
802
+ out_proj: Linear,
803
+ xpos: Xpos,
804
+ num_heads: usize,
805
+ head_dim: usize,
806
+ scaling: f64,
807
+ }
808
+
809
+ impl XposMultiheadAttention {
810
+ fn new(vb: VarBuilder, embed_dim: usize, num_heads: usize) -> Result<Self> {
811
+ let head_dim = embed_dim / num_heads;
812
+ Ok(Self {
813
+ k_proj: load_linear(vb.pp("k_proj"), embed_dim, embed_dim)?,
814
+ v_proj: load_linear(vb.pp("v_proj"), embed_dim, embed_dim)?,
815
+ q_proj: load_linear(vb.pp("q_proj"), embed_dim, embed_dim)?,
816
+ out_proj: load_linear(vb.pp("out_proj"), embed_dim, embed_dim)?,
817
+ xpos: Xpos::new(vb.pp("xpos"), head_dim, embed_dim)?,
818
+ num_heads,
819
+ head_dim,
820
+ scaling: (head_dim as f64).powf(-0.5),
821
+ })
822
+ }
823
+
824
+ fn forward(
825
+ &self,
826
+ query: &Tensor,
827
+ key: &Tensor,
828
+ value: &Tensor,
829
+ key_padding_mask: Option<&Tensor>,
830
+ k_offset: usize,
831
+ q_offset: usize,
832
+ ) -> Result<Tensor> {
833
+ let (batch, tgt_len, embed_dim) = query.dims3()?;
834
+ let (_, src_len, _) = key.dims3()?;
835
+ anyhow::ensure!(
836
+ embed_dim == self.num_heads * self.head_dim,
837
+ "unexpected attention dim: {embed_dim}"
838
+ );
839
+
840
+ let q = self
841
+ .q_proj
842
+ .forward(query)?
843
+ .affine(self.scaling, 0.0)?
844
+ .reshape((batch, tgt_len, self.num_heads, self.head_dim))?
845
+ .transpose(1, 2)?
846
+ .reshape((batch * self.num_heads, tgt_len, self.head_dim))?;
847
+ let k = self
848
+ .k_proj
849
+ .forward(key)?
850
+ .reshape((batch, src_len, self.num_heads, self.head_dim))?
851
+ .transpose(1, 2)?
852
+ .reshape((batch * self.num_heads, src_len, self.head_dim))?;
853
+ let v = self
854
+ .v_proj
855
+ .forward(value)?
856
+ .reshape((batch, src_len, self.num_heads, self.head_dim))?
857
+ .transpose(1, 2)?
858
+ .reshape((batch * self.num_heads, src_len, self.head_dim))?;
859
+
860
+ let q = self.xpos.forward(&q, q_offset, false)?;
861
+ let k = self.xpos.forward(&k, k_offset, true)?;
862
+
863
+ let mut attn_weights = q.matmul(&k.transpose(1, 2)?)?;
864
+ if let Some(mask) = key_padding_mask {
865
+ let attn_weights_4d =
866
+ attn_weights.reshape((batch, self.num_heads, tgt_len, src_len))?;
867
+ let mask = mask
868
+ .reshape((batch, 1, 1, src_len))?
869
+ .broadcast_as(attn_weights_4d.shape().dims())?;
870
+ let neg_inf = Tensor::full(
871
+ f32::NEG_INFINITY,
872
+ attn_weights_4d.shape().dims(),
873
+ attn_weights_4d.device(),
874
+ )?;
875
+ attn_weights = mask.where_cond(&neg_inf, &attn_weights_4d)?.reshape((
876
+ batch * self.num_heads,
877
+ tgt_len,
878
+ src_len,
879
+ ))?;
880
+ }
881
+
882
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
883
+ let attn = attn_weights
884
+ .matmul(&v)?
885
+ .reshape((batch, self.num_heads, tgt_len, self.head_dim))?
886
+ .transpose(1, 2)?
887
+ .reshape((batch, tgt_len, embed_dim))?;
888
+ Ok(self.out_proj.forward(&attn)?)
889
+ }
890
+ }
891
+
892
+ struct Xpos {
893
+ scale: Tensor,
894
+ scale_base: usize,
895
+ }
896
+
897
+ impl Xpos {
898
+ fn new(vb: VarBuilder, head_dim: usize, scale_base: usize) -> Result<Self> {
899
+ let scale = vb.get(head_dim / 2, "scale")?;
900
+ Ok(Self { scale, scale_base })
901
+ }
902
+
903
+ fn forward(&self, xs: &Tensor, offset: usize, downscale: bool) -> Result<Tensor> {
904
+ let (_, length, head_dim) = xs.dims3()?;
905
+ if length == 0 {
906
+ return Ok(xs.clone());
907
+ }
908
+ let half_dim = head_dim / 2;
909
+ let min_pos = -((length + offset) as i64 / 2);
910
+ let max_pos = length as i64 + offset as i64 + min_pos;
911
+ let exponents = Tensor::arange(min_pos as f32, max_pos as f32, xs.device())?
912
+ .affine(1.0 / self.scale_base as f64, 0.0)?
913
+ .reshape(((max_pos - min_pos) as usize, 1))?;
914
+ let mut scale = self.scale.broadcast_pow(&exponents)?;
915
+ let (mut sin, mut cos) = fixed_pos_embedding(scale.dims2()?.0, half_dim, xs.device())?;
916
+
917
+ if scale.dim(0)? > length {
918
+ let start = scale.dim(0)? - length;
919
+ scale = scale.narrow(0, start, length)?;
920
+ sin = sin.narrow(0, start, length)?;
921
+ cos = cos.narrow(0, start, length)?;
922
+ }
923
+ if downscale {
924
+ scale = scale.recip()?;
925
+ }
926
+ apply_rotary_pos_emb(xs, &sin, &cos, &scale)
927
+ }
928
+ }
929
+
930
+ fn fixed_pos_embedding(seq_len: usize, dim: usize, device: &Device) -> Result<(Tensor, Tensor)> {
931
+ let positions = Tensor::arange(0f32, seq_len as f32, device)?.reshape((seq_len, 1))?;
932
+ let inv_freq = Tensor::arange(0f32, dim as f32, device)?
933
+ .affine(-(10000f32.ln() as f64) / dim as f64, 0.0)?
934
+ .exp()?
935
+ .reshape((1, dim))?;
936
+ let sinusoid = positions.broadcast_mul(&inv_freq)?;
937
+ Ok((sinusoid.sin()?, sinusoid.cos()?))
938
+ }
939
+
940
+ fn duplicate_interleave(xs: &Tensor) -> Result<Tensor> {
941
+ let (rows, cols) = xs.dims2()?;
942
+ Ok(xs
943
+ .reshape((rows * cols, 1))?
944
+ .repeat((1, 2))?
945
+ .reshape((rows, cols * 2))?)
946
+ }
947
+
948
+ fn rotate_every_two(xs: &Tensor) -> Result<Tensor> {
949
+ let head_dim = xs.dim(D::Minus1)?;
950
+ let even = Tensor::arange_step(0u32, head_dim as u32, 2u32, xs.device())?;
951
+ let odd = Tensor::arange_step(1u32, head_dim as u32, 2u32, xs.device())?;
952
+ let x1 = xs.index_select(&even, D::Minus1)?;
953
+ let x2 = xs.index_select(&odd, D::Minus1)?;
954
+ Ok(Tensor::stack(&[&x2.neg()?, &x1], D::Minus1)?.flatten_from(D::Minus2)?)
955
+ }
956
+
957
+ fn apply_rotary_pos_emb(xs: &Tensor, sin: &Tensor, cos: &Tensor, scale: &Tensor) -> Result<Tensor> {
958
+ let sin = duplicate_interleave(&sin.broadcast_mul(scale)?)?;
959
+ let cos = duplicate_interleave(&cos.broadcast_mul(scale)?)?;
960
+ let sin = sin.reshape((1, sin.dim(0)?, sin.dim(1)?))?;
961
+ let cos = cos.reshape((1, cos.dim(0)?, cos.dim(1)?))?;
962
+ Ok(xs
963
+ .broadcast_mul(&cos)?
964
+ .broadcast_add(&rotate_every_two(xs)?.broadcast_mul(&sin)?)?)
965
+ }
966
+
967
#[cfg(test)]
mod tests {
    use candle_core::{Device, Tensor, test_utils};

    use super::{duplicate_interleave, fixed_pos_embedding, rotate_every_two, topk_last_dim};

    /// Every element must appear twice in a row, row by row.
    #[test]
    fn duplicate_interleave_matches_python_behavior() -> anyhow::Result<()> {
        let input = Tensor::from_vec(vec![1f32, 2., 3., 4.], (2, 2), &Device::Cpu)?;
        let doubled = duplicate_interleave(&input)?.to_vec2::<f32>()?;
        assert_eq!(
            doubled,
            vec![vec![1.0, 1.0, 2.0, 2.0], vec![3.0, 3.0, 4.0, 4.0]]
        );
        Ok(())
    }

    /// Pairs `(a, b)` must come out as `(-b, a)`.
    #[test]
    fn rotate_every_two_matches_reference() -> anyhow::Result<()> {
        let input = Tensor::from_vec(vec![1f32, 2., 3., 4.], (1, 1, 4), &Device::Cpu)?;
        let rotated = rotate_every_two(&input)?.to_vec3::<f32>()?;
        assert_eq!(rotated, vec![vec![vec![-2.0, 1.0, -4.0, 3.0]]]);
        Ok(())
    }

    /// Pins the sine/cosine tables for three positions and two frequencies.
    #[test]
    fn fixed_pos_embedding_shape_and_values_are_stable() -> anyhow::Result<()> {
        let (sin, cos) = fixed_pos_embedding(3, 2, &Device::Cpu)?;
        let sin_rounded = test_utils::to_vec2_round(&sin, 4)?;
        let cos_rounded = test_utils::to_vec2_round(&cos, 4)?;
        assert_eq!(sin_rounded, &[[0.0, 0.0], [0.8415, 0.01], [0.9093, 0.02]]);
        assert_eq!(cos_rounded, &[[1.0, 1.0], [0.5403, 1.0], [-0.4161, 0.9998]]);
        Ok(())
    }

    /// Top-k must return the best scores first, paired with their source indices.
    #[test]
    fn topk_last_dim_returns_descending_scores_and_indices() -> anyhow::Result<()> {
        let scores = Tensor::from_vec(vec![0.1f32, 0.9, 0.3, 0.7], (1, 4), &Device::Cpu)?;
        let (values, indices) = topk_last_dim(&scores, 3)?;
        assert_eq!(values, vec![vec![0.9, 0.7, 0.3]]);
        assert_eq!(indices, vec![vec![1, 3, 2]]);
        Ok(())
    }
}
koharu-ml/tests/ocr.rs CHANGED
@@ -1,21 +1,34 @@
1
  use std::path::Path;
2
 
3
- use koharu_ml::manga_ocr::MangaOcr;
 
4
 
5
  #[tokio::test]
6
  #[ignore]
7
- async fn manga_ocr_reads_dialog_image() -> anyhow::Result<()> {
8
  let fixtures = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
9
- let image = image::open(fixtures.join("dialog.jpg"))?;
 
 
 
 
 
 
 
10
 
11
- let ocr = MangaOcr::load(false).await?;
12
- let results = ocr.inference(&[image])?;
13
 
14
  assert_eq!(results.len(), 1);
15
  assert!(
16
- !results[0].trim().is_empty(),
17
  "OCR result should contain text"
18
  );
 
 
 
 
 
19
 
20
  Ok(())
21
  }
 
1
  use std::path::Path;
2
 
3
+ use koharu_ml::mit48px_ocr::Mit48pxOcr;
4
+ use koharu_types::TextBlock;
5
 
6
  #[tokio::test]
7
  #[ignore]
8
+ async fn mit48px_reads_dialog_image_via_default_block_path() -> anyhow::Result<()> {
9
  let fixtures = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
10
+ let image = image::open(fixtures.join("1.jpg"))?.crop_imm(66, 26, 270, 48);
11
+ let block = TextBlock {
12
+ x: 0.0,
13
+ y: 0.0,
14
+ width: image.width() as f32,
15
+ height: image.height() as f32,
16
+ ..Default::default()
17
+ };
18
 
19
+ let ocr = Mit48pxOcr::load(false).await?;
20
+ let results = ocr.inference_text_blocks(&image, &[block])?;
21
 
22
  assert_eq!(results.len(), 1);
23
  assert!(
24
+ !results[0].text.trim().is_empty(),
25
  "OCR result should contain text"
26
  );
27
+ assert!(
28
+ results[0].text.contains("対策"),
29
+ "unexpected OCR output: {}",
30
+ results[0].text
31
+ );
32
 
33
  Ok(())
34
  }
scripts/convert_mit48px.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert mit48px OCR weights to safetensors for Candle.
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import shutil
9
+ from pathlib import Path
10
+
11
+ from huggingface_hub import hf_hub_download
12
+ from safetensors.torch import save_file
13
+ import torch
14
+
15
+
16
+ MODEL_REPO = "zyddnys/manga-image-translator"
17
+ MODEL_FILENAME = "ocr_ar_48px.ckpt"
18
+ DICT_FILENAME = "alphabet-all-v7.txt"
19
+
20
+
21
def parse_args() -> argparse.Namespace:
    """Parse the command line of the conversion script.

    Returns:
        A namespace with ``checkpoint`` and ``dictionary`` (``Path`` or ``None``
        when the file should be downloaded) and ``output_dir`` (``Path``,
        defaulting to ``~/.cache/Koharu/models/mit48px-ocr``).
    """
    default_output = Path.home() / ".cache" / "Koharu" / "models" / "mit48px-ocr"
    parser = argparse.ArgumentParser(description="Convert mit48px OCR checkpoint to safetensors.")

    # Both inputs are optional local overrides of the downloaded files.
    optional_inputs = (
        (
            "--checkpoint",
            "Optional local checkpoint path. Defaults to downloading ocr_ar_48px.ckpt.",
        ),
        (
            "--dictionary",
            "Optional local dictionary path. Defaults to downloading alphabet-all-v7.txt.",
        ),
    )
    for flag, help_text in optional_inputs:
        parser.add_argument(flag, type=Path, default=None, help=help_text)

    parser.add_argument(
        "-o",
        "--output-dir",
        type=Path,
        default=default_output,
        help=f"Output directory (default: {default_output})",
    )
    return parser.parse_args()
44
+
45
+
46
def load_state_dict(checkpoint_path: Path) -> dict[str, torch.Tensor]:
    """Load a checkpoint and return its tensors as detached, contiguous CPU copies.

    A checkpoint that wraps its weights in a ``"state_dict"`` dict is unwrapped
    first. Every tensor is cloned, so no two entries of the result share storage.

    Raises:
        RuntimeError: if the checkpoint is not a dict, or if any entry is not a
            ``torch.Tensor``.
    """
    # NOTE(review): torch.load unpickles the file; depending on the installed torch
    # version this may execute code embedded in the checkpoint. Only run this on
    # checkpoints from a trusted source, or consider passing weights_only=True.
    loaded = torch.load(checkpoint_path, map_location="cpu")

    nested = loaded.get("state_dict") if isinstance(loaded, dict) else None
    if isinstance(nested, dict):
        loaded = nested
    if not isinstance(loaded, dict):
        raise RuntimeError("Unexpected checkpoint format")

    converted = {}
    for key, value in loaded.items():
        if isinstance(value, torch.Tensor):
            converted[key] = value.detach().cpu().contiguous().clone()
            continue
        raise RuntimeError(f"Unexpected non-tensor entry for key {key!r}")
    return converted
58
+
59
+
60
def main() -> None:
    """Convert the checkpoint and write all model files into the output directory.

    Writes ``model.safetensors``, a copy of the dictionary file and ``config.json``.
    The checkpoint and the dictionary are downloaded from ``MODEL_REPO`` unless a
    local path was given on the command line.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_path = args.checkpoint
    if not checkpoint_path:
        checkpoint_path = Path(hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME))
    dictionary_path = args.dictionary
    if not dictionary_path:
        dictionary_path = Path(hf_hub_download(repo_id=MODEL_REPO, filename=DICT_FILENAME))

    state_dict = load_state_dict(checkpoint_path)
    save_file(state_dict, str(args.output_dir / "model.safetensors"))
    shutil.copyfile(dictionary_path, args.output_dir / DICT_FILENAME)

    # Hyper-parameters consumed by the Rust loader; key order is preserved in the JSON.
    config = {
        "text_height": 48,
        "max_width": 8100,
        "embd_dim": 320,
        "num_heads": 4,
        "encoder_layers": 4,
        "decoder_layers": 5,
        "beam_size_default": 5,
        "max_seq_length_default": 255,
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "space_token": "<SP>",
        "dictionary_file": DICT_FILENAME,
    }
    serialized = json.dumps(config, ensure_ascii=False, indent=2)
    with open(args.output_dir / "config.json", "w", encoding="utf-8") as fp:
        fp.write(serialized)
        fp.write("\n")

    print(f"Saved {len(state_dict)} tensors to {args.output_dir / 'model.safetensors'}")
    print(f"Saved dictionary to {args.output_dir / DICT_FILENAME}")
    print(f"Saved config to {args.output_dir / 'config.json'}")


if __name__ == "__main__":
    main()