RichardScottOZ committed · verified
Commit c047b44 · Parent(s): 0de2bbd

update model card

Add basic model card.

Files changed (1): README.md (+147, −0)
---
language: en
tags:
- vision
- text
- multimodal
- comics
- page-classification
- bert
license: mit
---

# CoSMo v4 (Comic Stream Modeling - Page Classifier)

CoSMo v4 is a specialized multimodal classifier that categorizes the pages of a comic book archive into distinct structural classes (e.g., *story, cover, advertisement, credits*).

It represents "Stage 2" of the [Comic Analysis Framework v2.0](https://github.com/RichardScottOZ/Comic-Analysis), acting as the gatekeeper that filters raw comic archives down to pure narrative content for downstream sequence modeling.

This v4 iteration introduces the **BookBERTMultimodal2** architecture, which replaces standard convolutional feature extractors with modern vision-language models, achieving state-of-the-art accuracy on unstructured comic data.

## Model Architecture

CoSMo v4 is based on the `BookBERTMultimodal2` class. It treats a comic book as a "sequence" of pages and uses a Transformer encoder to understand the context of a page based on its position in the book.

1. **Visual Features (`1152-dim`):** Extracted with **SigLIP** (`google/siglip-so400m-patch14-384`).
2. **Text Features (`1024-dim`):** Extracted from OCR text with **Qwen-Embedding** (`Qwen/Qwen3-Embedding-0.6B`).
3. **Projections:** Deep MLP projection layers (visual: `1152 -> 3840 -> 1920 -> 768`; text: `1024 -> 3584 -> 1792 -> 768`) align both modalities into a common `768-dim` space.
4. **Contextual Encoding:** A 4-layer, 4-head **BERT encoder** (`transformers.BertModel`) processes the combined features across the entire length of the comic book, allowing the model to learn, for example, that advertisements usually follow story pages and credits appear at the end.
5. **Classification Head:** A deep sequential classifier maps each contextualized `768-dim` token to one of **9 distinct classes**.

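The projection widths follow directly from the input and BERT dimensions: each path starts at `(input_dim + bert_dim) * 2`, halves, then lands on `bert_dim`. A quick sanity check of that arithmetic (plain Python, no model weights needed):

```python
visual_dim, textual_dim, bert_dim = 1152, 1024, 768

# Visual path: (1152 + 768) * 2 = 3840, halved to 1920, then 768.
sz_v = (visual_dim + bert_dim) * 2
print(sz_v, sz_v // 2, bert_dim)  # 3840 1920 768

# Text path follows the same rule from a 1024-dim input.
sz_t = (textual_dim + bert_dim) * 2
print(sz_t, sz_t // 2, bert_dim)  # 3584 1792 768
```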

## Output Classes

The model predicts one of 9 labels for every page:

1. `advertisement`
2. `cover`
3. `story` (the primary narrative content)
4. `textstory`
5. `first-page`
6. `credits`
7. `art` (splash pages, pin-ups)
8. `text` (editorial text)
9. `back_cover`

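For convenience, the same labels as lookup tables, in the index order used by the inference snippet's `class_names` (list position 1-9 maps to output indices 0-8):

```python
CLASS_NAMES = [
    "advertisement", "cover", "story", "textstory", "first-page",
    "credits", "art", "text", "back_cover",
]

# Bidirectional mappings between output indices and label names.
id2label = dict(enumerate(CLASS_NAMES))
label2id = {name: idx for idx, name in id2label.items()}
print(id2label[2], label2id["back_cover"])  # story 8
```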

## Usage

Because CoSMo v4 requires pre-computed SigLIP and Qwen embeddings, inference is typically a two-step process: generate the per-page embeddings, then run the classifier over the whole book. The complete codebase for embedding generation and Zarr-based inference is available in the [Comic Analysis GitHub Repository](https://github.com/RichardScottOZ/Comic-Analysis) under `src/cosmo/`.

### Quick Start Inference Snippet

If you already have visual (`1152-d`) and text (`1024-d`) embeddings for a sequence of pages, you can run inference like this:

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

# 1. Define the architecture (must match the checkpoint exactly)
class BookBERT(nn.Module):
    def __init__(self, bert_input=768, num_classes=9, hidden_dim=512, dropout_p=0.0):
        super().__init__()
        config = BertConfig(
            hidden_size=bert_input, num_hidden_layers=4, num_attention_heads=4,
            intermediate_size=bert_input * 4, max_position_embeddings=1024
        )
        self.bert_encoder = BertModel(config)
        self.classifier = nn.Sequential(
            nn.Linear(bert_input, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.LayerNorm(hidden_dim // 4),
            nn.GELU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim // 4, num_classes)
        )

class BookBERTMultimodal2(BookBERT):
    def __init__(self, textual_dim=1024, visual_dim=1152, bert_dim=768, classes=9):
        super().__init__(bert_input=bert_dim, num_classes=classes, hidden_dim=512, dropout_p=0.0)

        # Visual projection: 1152 -> 3840 -> 1920 -> 768
        sz1_v = (visual_dim + bert_dim) * 2
        self.visual_projection = nn.Sequential(
            nn.Linear(visual_dim, sz1_v), nn.LayerNorm(sz1_v), nn.GELU(), nn.Dropout(0.0),
            nn.Linear(sz1_v, sz1_v // 2), nn.LayerNorm(sz1_v // 2), nn.GELU(), nn.Dropout(0.0),
            nn.Linear(sz1_v // 2, bert_dim)
        )

        # Textual projection: 1024 -> 3584 -> 1792 -> 768
        sz1_t = (textual_dim + bert_dim) * 2
        self.textual_projection = nn.Sequential(
            nn.Linear(textual_dim, sz1_t), nn.LayerNorm(sz1_t), nn.GELU(), nn.Dropout(0.0),
            nn.Linear(sz1_t, sz1_t // 2), nn.LayerNorm(sz1_t // 2), nn.GELU(), nn.Dropout(0.0),
            nn.Linear(sz1_t // 2, bert_dim)
        )
        self.norm = nn.LayerNorm(bert_dim)

    def forward(self, textual_features, visual_features):
        batch_size, seq_len, _ = textual_features.shape
        mask = torch.ones((batch_size, seq_len), device=textual_features.device)

        t_norm = self.norm(self.textual_projection(textual_features))
        v_norm = self.norm(self.visual_projection(visual_features))

        # Interleave text/visual tokens per page: (B, seq_len * 2, 768)
        combined = torch.stack([t_norm, v_norm], dim=2).view(batch_size, seq_len * 2, -1)
        exp_mask = mask.unsqueeze(2).expand(-1, -1, 2).reshape(batch_size, seq_len * 2)

        bert_out = self.bert_encoder(inputs_embeds=combined, attention_mask=exp_mask)
        # Regroup token pairs and classify from the second (visual) token of each page
        reshaped = bert_out.last_hidden_state.view(batch_size, seq_len, 2, -1)
        return self.classifier(reshaped[:, :, -1, :])

# 2. Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BookBERTMultimodal2().to(device)

state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/RichardScottOZ/cosmo-v4/resolve/main/best_Multimodal_MultiToken_v4.pt",
    map_location=device
)
if 'model_state_dict' in state_dict:
    state_dict = state_dict['model_state_dict']
model.load_state_dict(state_dict, strict=True)
model.eval()

# 3. Inference (example: one comic book containing 24 pages)
# visual_embeddings shape: (1, 24, 1152) -> from SigLIP
# text_embeddings shape:   (1, 24, 1024) -> from Qwen
visual_embs = torch.randn(1, 24, 1152).to(device)
text_embs = torch.randn(1, 24, 1024).to(device)

with torch.inference_mode():
    logits = model(text_embs, visual_embs)
    predictions = torch.argmax(logits, dim=-1).squeeze(0)

class_names = ["advertisement", "cover", "story", "textstory", "first-page", "credits", "art", "text", "back_cover"]
for page_num, pred_idx in enumerate(predictions):
    print(f"Page {page_num}: {class_names[pred_idx]}")
```
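Beyond the argmax, the logits can be softmaxed into per-class probabilities to flag low-confidence pages for manual review. A standalone sketch (random logits stand in for real model output, and the `0.5` threshold is an illustrative assumption):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 24, 9)  # stand-in for the model's (batch, pages, classes) output

probs = torch.softmax(logits, dim=-1)
confidence, predicted = probs.max(dim=-1)

# Flag pages whose top-class probability falls below a hypothetical threshold.
needs_review = (confidence.squeeze(0) < 0.5).nonzero(as_tuple=True)[0]
print(f"{len(needs_review)} of 24 pages flagged for review")
```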

## Intended Use

This model is designed to process an entire comic book/issue as a single sequence. Because of the positional embeddings in the BERT encoder, feeding it pages completely out of order, or a single page at a time without context, will degrade performance.

*Note: The model has a hard limit of `1024` tokens, equating to `512` pages per forward pass (each page consumes one text token and one visual token). Massive omnibuses must be chunked.*
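A minimal chunking sketch for books longer than 512 pages. The window and overlap sizes here are illustrative assumptions, not values from the training setup:

```python
def page_chunks(num_pages, max_pages=512, overlap=32):
    """Yield (start, end) page windows that each fit in one forward pass."""
    start = 0
    while start < num_pages:
        end = min(start + max_pages, num_pages)
        yield start, end
        if end == num_pages:
            break
        start = end - overlap  # overlap keeps some context across boundaries

# A 1100-page omnibus splits into three overlapping windows.
print(list(page_chunks(1100)))  # [(0, 512), (480, 992), (960, 1100)]
```

Predictions in the overlapping region can be taken from whichever window places the page further from a boundary, since mid-window pages see the most context.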

## Citation

If you use this model or the framework, please reference the [Comic Analysis GitHub Repository](https://github.com/RichardScottOZ/Comic-Analysis).

## Example of Cosmov4 predictions

```python