Johnny-Z committed on
Commit
1aee7bf
·
verified ·
1 Parent(s): d52be25

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -56
app.py CHANGED
@@ -10,41 +10,44 @@ from huggingface_hub import login, snapshot_download
10
  TITLE = "Danbooru Tagger"
11
  DESCRIPTION = """
12
  ## Dataset
13
- - Source: Cleaned Danbooru
14
-
15
- ## Metrics
16
  - Validation Split: 10% of Dataset
17
- - Validation Results:
 
18
 
19
  ### General
 
20
  | Metric | Value |
21
  |-----------------|-------------|
22
- | Macro F1 | 0.4678 |
23
- | Macro Precision | 0.4605 |
24
- | Macro Recall | 0.5229 |
25
- | Micro F1 | 0.6661 |
26
- | Micro Precision | 0.6049 |
27
- | Micro Recall | 0.7411 |
28
 
29
  ### Character
 
30
  | Metric | Value |
31
  |-----------------|-------------|
32
- | Macro F1 | 0.8925 |
33
- | Macro Precision | 0.9099 |
34
- | Macro Recall | 0.8935 |
35
- | Micro F1 | 0.9232 |
36
- | Micro Precision | 0.9264 |
37
- | Micro Recall | 0.9199 |
38
 
39
  ### Artist
 
40
  | Metric | Value |
41
  |-----------------|-------------|
42
- | Macro F1 | 0.7904 |
43
- | Macro Precision | 0.8286 |
44
- | Macro Recall | 0.7904 |
45
- | Micro F1 | 0.5989 |
46
- | Micro Precision | 0.5975 |
47
- | Micro Recall | 0.6004 |
48
  """
49
 
50
  kaomojis = [
@@ -78,7 +81,7 @@ if hf_token:
78
  else:
79
  raise ValueError("environment variable HF_TOKEN not found.")
80
 
81
- repo_id = "Johnny-Z/vit-e4"
82
  repo_dir = snapshot_download(repo_id)
83
  model = AutoModel.from_pretrained(repo_id, dtype=dtype, trust_remote_code=True, device_map=device)
84
 
@@ -127,25 +130,6 @@ class MLP(nn.Module):
127
  x = self.sigmoid(x)
128
  return x
129
 
130
- class MLP_Retrieval(nn.Module):
131
- def __init__(self, input_size, class_num):
132
- super().__init__()
133
- self.mlp_layer0 = nn.Sequential(
134
- nn.Linear(input_size, input_size // 2),
135
- nn.SiLU()
136
- )
137
- self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
138
-
139
- def forward(self, x):
140
- x = self.mlp_layer0(x)
141
- x = self.mlp_layer1(x)
142
- x1, x2 = x[:, :15], x[:, 15:]
143
- x1 = torch.softmax(x1, dim=1)
144
- x2 = torch.softmax(x2, dim=1)
145
- x = torch.cat([x1, x2], dim=1)
146
-
147
- return x
148
-
149
  with open(os.path.join(repo_dir, 'general_tag_dict.json'), 'r', encoding='utf-8') as f:
150
  general_dict = json.load(f)
151
 
@@ -171,25 +155,21 @@ model_map = MultiheadAttentionPoolingHead(2048)
171
  model_map.load_state_dict(torch.load(os.path.join(repo_dir, "map_head.pth"), map_location=device, weights_only=True))
172
  model_map.to(device).to(dtype).eval()
173
 
174
- general_class = 9775
175
  mlp_general = MLP(2048, general_class)
176
  mlp_general.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_general.pth"), map_location=device, weights_only=True))
177
  mlp_general.to(device).to(dtype).eval()
178
 
179
- character_class = 7568
180
  mlp_character = MLP(2048, character_class)
181
  mlp_character.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_character.pth"), map_location=device, weights_only=True))
182
  mlp_character.to(device).to(dtype).eval()
183
 
184
- artist_class = 13957
185
  mlp_artist = MLP(2048, artist_class)
186
  mlp_artist.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
187
  mlp_artist.to(device).to(dtype).eval()
188
 
189
- mlp_artist_retrieval = MLP_Retrieval(2048, artist_class)
190
- mlp_artist_retrieval.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist_retrieval.pth"), map_location=device, weights_only=True))
191
- mlp_artist_retrieval.to(device).to(dtype).eval()
192
-
193
  def prediction_to_tag(prediction, tag_dict, class_num):
194
  prediction = prediction.view(class_num)
195
  predicted_ids = (prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
@@ -273,17 +253,10 @@ def process_image(image):
273
  character_ = prediction_to_tag(character_prediction, character_dict, character_class)
274
  character_tags = character_[1]
275
 
276
- """
277
  artist_prediction = mlp_artist(embedding)
278
  artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
279
  artist_tags = artist_[2]
280
  date = artist_[3]
281
- """
282
-
283
- artist_retrieval_prediction = mlp_artist_retrieval(embedding)
284
- artist_retrieval_ = prediction_to_retrieval(artist_retrieval_prediction, artist_dict, artist_class, 10)
285
- artist_tags = artist_retrieval_[0]
286
- date = artist_retrieval_[1]
287
 
288
  combined_tags = {**general_tags}
289
 
 
10
  TITLE = "Danbooru Tagger"
11
  DESCRIPTION = """
12
  ## Dataset
13
+ - Source: Danbooru
14
+ - Cutoff Date: 2025-11-27
 
15
  - Validation Split: 10% of Dataset
16
+
17
+ ## Validation Results
18
 
19
  ### General
20
+ Tags Count: 11046
21
  | Metric | Value |
22
  |-----------------|-------------|
23
+ | Macro F1 | 0.4439 |
24
+ | Macro Precision | 0.4168 |
25
+ | Macro Recall | 0.4964 |
26
+ | Micro F1 | 0.6595 |
27
+ | Micro Precision | 0.5982 |
28
+ | Micro Recall | 0.7349 |
29
 
30
  ### Character
31
+ Tags Count: 9148
32
  | Metric | Value |
33
  |-----------------|-------------|
34
+ | Macro F1 | 0.8646 |
35
+ | Macro Precision | 0.8897 |
36
+ | Macro Recall | 0.8492 |
37
+ | Micro F1 | 0.9092 |
38
+ | Micro Precision | 0.9195 |
39
+ | Micro Recall | 0.8991 |
40
 
41
  ### Artist
42
+ Tags Count: 17171
43
  | Metric | Value |
44
  |-----------------|-------------|
45
+ | Macro F1 | 0.8008 |
46
+ | Macro Precision | 0.8669 |
47
+ | Macro Recall | 0.7641 |
48
+ | Micro F1 | 0.8596 |
49
+ | Micro Precision | 0.8948 |
50
+ | Micro Recall | 0.8271 |
51
  """
52
 
53
  kaomojis = [
 
81
  else:
82
  raise ValueError("environment variable HF_TOKEN not found.")
83
 
84
+ repo_id = "Johnny-Z/danbooru_vfm"
85
  repo_dir = snapshot_download(repo_id)
86
  model = AutoModel.from_pretrained(repo_id, dtype=dtype, trust_remote_code=True, device_map=device)
87
 
 
130
  x = self.sigmoid(x)
131
  return x
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  with open(os.path.join(repo_dir, 'general_tag_dict.json'), 'r', encoding='utf-8') as f:
134
  general_dict = json.load(f)
135
 
 
155
  model_map.load_state_dict(torch.load(os.path.join(repo_dir, "map_head.pth"), map_location=device, weights_only=True))
156
  model_map.to(device).to(dtype).eval()
157
 
158
+ general_class = 11046
159
  mlp_general = MLP(2048, general_class)
160
  mlp_general.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_general.pth"), map_location=device, weights_only=True))
161
  mlp_general.to(device).to(dtype).eval()
162
 
163
+ character_class = 9148
164
  mlp_character = MLP(2048, character_class)
165
  mlp_character.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_character.pth"), map_location=device, weights_only=True))
166
  mlp_character.to(device).to(dtype).eval()
167
 
168
+ artist_class = 17171
169
  mlp_artist = MLP(2048, artist_class)
170
  mlp_artist.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
171
  mlp_artist.to(device).to(dtype).eval()
172
 
 
 
 
 
173
  def prediction_to_tag(prediction, tag_dict, class_num):
174
  prediction = prediction.view(class_num)
175
  predicted_ids = (prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
 
253
  character_ = prediction_to_tag(character_prediction, character_dict, character_class)
254
  character_tags = character_[1]
255
 
 
256
  artist_prediction = mlp_artist(embedding)
257
  artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
258
  artist_tags = artist_[2]
259
  date = artist_[3]
 
 
 
 
 
 
260
 
261
  combined_tags = {**general_tags}
262