Ole1 committed commit a0cef3f (verified) · 1 parent: b0aaf79

Update Run_gui.py

Files changed (1):
  1. Run_gui.py +1548 -36

Run_gui.py CHANGED
@@ -1,3 +1,1529 @@
  import sys
  import os
  import torch
@@ -114,7 +1640,7 @@ EXTRA_OPTIONS_LIST = [

  CAPTION_LENGTH_CHOICES = (
      ["any", "very short", "short", "medium-length", "long", "very long"]
-     + [str(i) for i in range(20, 261, 10)]
  )

  HF_TOKEN = os.environ.get("HF_TOKEN", None)
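
Note: with the continuation line gone, the parenthesized expression still parses (it now wraps a single list), but every numeric word-count option disappears from the length dropdown. For reference, a sketch of what the full expression evaluated to before this change:

CAPTION_LENGTH_CHOICES = (
    ["any", "very short", "short", "medium-length", "long", "very long"]
    + [str(i) for i in range(20, 261, 10)]  # "20", "30", ..., "260"
)
assert len(CAPTION_LENGTH_CHOICES) == 6 + 25  # six named lengths plus 25 word counts
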
@@ -183,7 +1709,7 @@ class ImageAdapter(nn.Module):

          if self.pos_emb is not None:
              assert x.shape[-2:] == self.pos_emb.shape
-             x = x + self.pos_emb

          x = self.linear1(x)
          x = self.activation(x)
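
Note: dropping `x = x + self.pos_emb` leaves the learned positional embedding allocated and shape-checked but never applied. The deleted line is a broadcast add over the batch dimension; a runnable sketch with illustrative sizes (38 image tokens matches the `ImageAdapter(...)` call in this file, the feature width is arbitrary):

import torch

pos_emb = torch.zeros(38, 1152)   # (num_image_tokens, input_features); sizes illustrative
x = torch.randn(4, 38, 1152)      # (batch, num_image_tokens, input_features)
assert x.shape[-2:] == pos_emb.shape
x = x + pos_emb                   # broadcasts over the batch dimension
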
@@ -307,7 +1833,7 @@ def generate_caption(
          else: raise ValueError(f"Invalid caption length: {length}")

          prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
-         if len(extra_options) > 0: prompt_str += " " + " ".join(extra_options)
          prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

      print(f"Prompt: {prompt_str}")
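
Note: without this line, checked extra options never reach the model; only the base template gets formatted. What the removed concatenation did, sketched:

extra_options = ["Include information about lighting.", "Include information about camera angle."]
prompt_str = "Write a descriptive caption for this image in a formal tone."
if len(extra_options) > 0:
    prompt_str += " " + " ".join(extra_options)
# -> the base prompt followed by each selected option, space-separated
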
@@ -335,7 +1861,7 @@ def generate_caption(
      if hasattr(tokenizer, "apply_chat_template"):
          convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
      else:
-         convo_string = ("<|system|>\n" + convo[0]["content"] + "\n<|end|>\n<|user|>\n" + convo[1]["content"] + "\n<|end|>\n")
      assert isinstance(convo_string, str)

      convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
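
Note: with the fallback deleted, a tokenizer that lacks `apply_chat_template` leaves `convo_string` unassigned, so the assert that follows raises `NameError` instead of building the manual template. The removed fallback, as a runnable sketch:

prompt_str = "Write a descriptive caption for this image in a formal tone."
convo = [
    {"role": "system", "content": "You are a helpful image captioner."},
    {"role": "user", "content": prompt_str},
]
convo_string = ("<|system|>\n" + convo[0]["content"] + "\n<|end|>\n<|user|>\n"
                + convo[1]["content"] + "\n<|end|>\n")
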
@@ -347,7 +1873,7 @@ def generate_caption(
      end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
      if end_token_id is None: raise ValueError("Tokenizer missing '<|end|>' token.")
      end_token_indices = (convo_tokens == end_token_id).nonzero(as_tuple=True)[0].tolist()
-     preamble_len = end_token_indices[0] + 1 if len(end_token_indices) >= 1 else 0

      convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
      input_embeds = torch.cat([
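
Note: `preamble_len` marks the end of the system preamble (the first `<|end|>` token, inclusive) so the image embeddings can be spliced in right after it; with the assignment gone, the `torch.cat` below fails with `NameError`. The index arithmetic, sketched on a toy tensor (42 stands in for the `<|end|>` id):

import torch

convo_tokens = torch.tensor([9, 9, 42, 7, 7, 7])  # system tokens, <|end|>, user tokens
end_token_indices = (convo_tokens == 42).nonzero(as_tuple=True)[0].tolist()
preamble_len = end_token_indices[0] + 1 if len(end_token_indices) >= 1 else 0
# image embeddings go between convo_tokens[:preamble_len] and convo_tokens[preamble_len:]
assert preamble_len == 3
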
@@ -374,7 +1900,7 @@ def generate_caption(
      generate_ids = generate_ids[:, input_ids.shape[1]:]
      caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
      caption = caption.strip()
-     caption = re.sub(r'\s+', ' ', caption)

      return prompt_str, caption

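
Note: the deleted `re.sub` collapsed every run of whitespace (spaces, tabs, newlines) in the decoded caption into a single space; without it, captions keep raw line breaks. Sketch:

import re

caption = "A cat\n  sitting on\tthe mat."
caption = re.sub(r'\s+', ' ', caption)
assert caption == "A cat sitting on the mat."
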
@@ -398,17 +1924,15 @@ class CaptionApp(QMainWindow):
          self.selected_image_path = None
          self.image_files = []

-         self.dark_mode = False  # Temporary

          self.central_widget = QWidget()
          self.setCentralWidget(self.central_widget)
          self.main_layout = QHBoxLayout(self.central_widget)

-         self.initUI()
-         self.dark_mode = True  # Enable dark mode after the UI is set up
-         print(self.dark_mode)
-         self.apply_theme()  # Force the theme to switch
          self.update_button_states()


      def initUI(self):
@@ -481,7 +2005,6 @@ class CaptionApp(QMainWindow):
          # Checkpoint Path
          ckpt_layout = QHBoxLayout()
          self.checkpoint_path_line = QLineEdit()
-         self.checkpoint_path_line.setText("cgrkzexw-599808")
          self.checkpoint_path_line.setToolTip("Path to the folder containing model files (clip_model.pt, etc.).")
          ckpt_layout.addWidget(QLabel("Checkpoint Path:"))
          ckpt_layout.addWidget(self.checkpoint_path_line)
@@ -709,7 +2232,7 @@ EXTRA_OPTIONS_LIST = [

  CAPTION_LENGTH_CHOICES = (
      ["any", "very short", "short", "medium-length", "long", "very long"]
-     + [str(i) for i in range(20, 261, 10)]
  )

  HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -778,7 +2301,7 @@ class ImageAdapter(nn.Module):

          if self.pos_emb is not None:
              assert x.shape[-2:] == self.pos_emb.shape
-             x = x + self.pos_emb

          x = self.linear1(x)
          x = self.activation(x)
@@ -902,7 +2425,7 @@ def generate_caption(
          else: raise ValueError(f"Invalid caption length: {length}")

          prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
-         if len(extra_options) > 0: prompt_str += " " + " ".join(extra_options)
          prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

      print(f"Prompt: {prompt_str}")
@@ -930,7 +2453,7 @@ def generate_caption(
      if hasattr(tokenizer, "apply_chat_template"):
          convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
      else:
-         convo_string = ("<|system|>\n" + convo[0]["content"] + "\n<|end|>\n<|user|>\n" + convo[1]["content"] + "\n<|end|>\n")
      assert isinstance(convo_string, str)

      convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
@@ -942,7 +2465,7 @@ def generate_caption(
      end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
      if end_token_id is None: raise ValueError("Tokenizer missing '<|end|>' token.")
      end_token_indices = (convo_tokens == end_token_id).nonzero(as_tuple=True)[0].tolist()
-     preamble_len = end_token_indices[0] + 1 if len(end_token_indices) >= 1 else 0

      convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
      input_embeds = torch.cat([
@@ -969,7 +2492,7 @@ def generate_caption(
      generate_ids = generate_ids[:, input_ids.shape[1]:]
      caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
      caption = caption.strip()
-     caption = re.sub(r'\s+', ' ', caption)

      return prompt_str, caption

@@ -1000,18 +2523,6 @@ class CaptionApp(QMainWindow):
          self.main_layout = QHBoxLayout(self.central_widget)

          self.initUI()  # Call initUI
-
-         # Attempt to auto-load models at startup
-
-         if Path("cgrkzexw-599808").exists():
-             try:
-                 (self.clip_processor, self.clip_model, self.tokenizer, self.text_model, self.image_adapter) = load_models(Path("cgrkzexw-599808"), status_callback=self.show_status)
-                 self.models_loaded = True
-                 self.show_status("Models loaded at startup.", 5000)
-             except Exception as e:
-                 print("Auto-load failed:", e)
-                 self.models_loaded = False
-
          self.update_button_states()
          self.apply_theme()
@@ -1235,7 +2746,7 @@ class CaptionApp(QMainWindow):
          placeholder_style = "QTextEdit { placeholderTextColor: gray; } QLineEdit { placeholderTextColor: gray; }"
          current_style = self.styleSheet()
          if self.dark_mode:
-             if "placeholderTextColor" not in current_style: self.setStyleSheet(current_style + placeholder_style)
          else: self.setStyleSheet(current_style.replace(placeholder_style, ""))

      def toggle_theme(self):
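
Note: removing the only statement in the `if self.dark_mode:` branch leaves it empty before the `else:`, which is a syntax error; the deleted line appended the placeholder style at most once. The intended logic, as a hypothetical standalone helper:

def _apply_placeholder_style(widget, dark_mode):
    # Hypothetical helper mirroring the removed logic: add the style once in
    # dark mode, strip it otherwise.
    placeholder_style = ("QTextEdit { placeholderTextColor: gray; } "
                         "QLineEdit { placeholderTextColor: gray; }")
    current_style = widget.styleSheet()
    if dark_mode:
        if "placeholderTextColor" not in current_style:
            widget.setStyleSheet(current_style + placeholder_style)
    else:
        widget.setStyleSheet(current_style.replace(placeholder_style, ""))
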
@@ -1480,14 +2991,14 @@ class CaptionApp(QMainWindow):

              # Track results (Approximate - relies on _save reporting skips)
              if caption_result is not None:
-                 processed_count += 1
              else:
                  # If None, assume error unless status bar indicates skip (imperfect)
                  if "Skipped saving" not in self.status_bar.currentMessage():
-                     error_count += 1
                  # No reliable way to count skips here without modifying _save return value

-             self.progress_bar.setValue(i + 1)
              QApplication.processEvents()

          # Restore overwrite checkbox state if changed
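
Note: with `processed_count += 1`, `error_count += 1`, and the progress update removed, both counters keep their initial values and the progress bar never advances, so the batch summary below always reports zero. The accounting pattern the loop relied on, in a runnable sketch (`process_one` is a hypothetical stand-in for the generate-and-save step):

def process_one(path):
    # Hypothetical: returns a caption string on success, None on failure.
    return "a caption" if path.endswith(".jpg") else None

image_files = ["a.jpg", "b.png", "c.jpg"]
processed_count = error_count = 0
for i, path in enumerate(image_files):
    caption_result = process_one(path)
    if caption_result is not None:
        processed_count += 1
    else:
        error_count += 1
print(f"Batch finished. {processed_count} captions generated/saved. {error_count} errors.")
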
@@ -1495,7 +3006,7 @@ class CaptionApp(QMainWindow):

          self.progress_bar.hide()
          final_message = f"Batch finished. {processed_count} captions generated/saved."
-         if error_count > 0: final_message += f" {error_count} errors."
          # Cannot reliably report skips here
          QMessageBox.information(self, "Batch Complete", final_message)
          self.show_status(final_message, 10000)
@@ -1537,4 +3048,5 @@ if __name__ == "__main__":
      app.setStyle("Fusion")  # Optional
      window = CaptionApp()
      window.show()
-     sys.exit(app.exec_())
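
Note: without `sys.exit(app.exec_())` the script builds the window and then falls off the end of the file, so the Qt event loop never runs and the GUI closes immediately. The standard PyQt5 entry point, using the names defined in this file:

if __name__ == "__main__":
    app = QApplication(sys.argv)
    app.setStyle("Fusion")  # Optional
    window = CaptionApp()
    window.show()
    sys.exit(app.exec_())  # start the event loop; exit with its return code
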
 
+ diff --git "a/Run_gui.py" "b/Run_gui.py"
+ --- "a/Run_gui.py"
+ +++ "b/Run_gui.py"
+ @@ -1,1523 +1,1525 @@
+ -import sys
+ -import os
+ -import torch
+ -from torch import nn
+ -from transformers import (
+ -    AutoModel,
+ -    AutoProcessor,
+ -    AutoTokenizer,
+ -    PreTrainedTokenizer,
+ -    PreTrainedTokenizerFast,
+ -    AutoModelForCausalLM,
+ -    BitsAndBytesConfig,
+ -)
+ -from PIL import Image
+ -import torchvision.transforms.functional as TVF
+ -import contextlib
+ -from typing import Union, List
+ -from pathlib import Path
+ -import re
+ -
+ -from PyQt5.QtWidgets import (
+ -    QApplication,
+ -    QWidget,
+ -    QLabel,
+ -    QPushButton,
+ -    QFileDialog,
+ -    QLineEdit,
+ -    QTextEdit,
+ -    QComboBox,
+ -    QVBoxLayout,
+ -    QHBoxLayout,
+ -    QCheckBox,
+ -    QListWidget,
+ -    QListWidgetItem,
+ -    QMessageBox,
+ -    QSizePolicy,
+ -    QStatusBar,
+ -    QProgressBar,
+ -    QMainWindow,
+ -)
+ -from PyQt5.QtGui import QPixmap, QIcon
+ -from PyQt5.QtCore import Qt, QTimer
+ -
+ -# --- Constants and Mappings ---
+ -CLIP_PATH = "google/siglip-so400m-patch14-384"
+ -CAPTION_TYPE_MAP = {
+ -    "Descriptive": [
+ -        "Write a descriptive caption for this image in a formal tone.",
+ -        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
+ -        "Write a {length} descriptive caption for this image in a formal tone.",
+ -    ],
+ -    "Descriptive (Informal)": [
+ -        "Write a descriptive caption for this image in a casual tone.",
+ -        "Write a descriptive caption for this image in a casual tone within {word_count} words.",
+ -        "Write a {length} descriptive caption for this image in a casual tone.",
+ -    ],
+ -    "Training Prompt": [
+ -        "Write a stable diffusion prompt for this image.",
+ -        "Write a stable diffusion prompt for this image within {word_count} words.",
+ -        "Write a {length} stable diffusion prompt for this image.",
+ -    ],
+ -    "MidJourney": [
+ -        "Write a MidJourney prompt for this image.",
+ -        "Write a MidJourney prompt for this image within {word_count} words.",
+ -        "Write a {length} MidJourney prompt for this image.",
+ -    ],
+ -    "Booru tag list": [
+ -        "Write a list of Booru tags for this image.",
+ -        "Write a list of Booru tags for this image within {word_count} words.",
+ -        "Write a {length} list of Booru tags for this image.",
+ -    ],
+ -    "Booru-like tag list": [
+ -        "Write a list of Booru-like tags for this image.",
+ -        "Write a list of Booru-like tags for this image within {word_count} words.",
+ -        "Write a {length} list of Booru-like tags for this image.",
+ -    ],
+ -    "Art Critic": [
+ -        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
+ -        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
+ -        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
+ -    ],
+ -    "Product Listing": [
+ -        "Write a caption for this image as though it were a product listing.",
+ -        "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
+ -        "Write a {length} caption for this image as though it were a product listing.",
+ -    ],
+ -    "Social Media Post": [
+ -        "Write a caption for this image as if it were being used for a social media post.",
+ -        "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
+ -        "Write a {length} caption for this image as if it were being used for a social media post.",
+ -    ],
+ -}
+ -
+ -EXTRA_OPTIONS_LIST = [
+ -    "If there is a person/character in the image you must refer to them as {name}.",
+ -    "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
+ -    "Include information about lighting.",
+ -    "Include information about camera angle.",
+ -    "Include information about whether there is a watermark or not.",
+ -    "Include information about whether there are JPEG artifacts or not.",
+ -    "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
+ -    "Do NOT include anything sexual; keep it PG.",
+ -    "Do NOT mention the image's resolution.",
+ -    "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
+ -    "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
+ -    "Do NOT mention any text that is in the image.",
+ -    "Specify the depth of field and whether the background is in focus or blurred.",
+ -    "If applicable, mention the likely use of artificial or natural lighting sources.",
+ -    "Do NOT use any ambiguous language.",
+ -    "Include whether the image is sfw, suggestive, or nsfw.",
+ -    "ONLY describe the most important elements of the image.",
+ -]
+ -
+ -CAPTION_LENGTH_CHOICES = (
+ -    ["any", "very short", "short", "medium-length", "long", "very long"]
+ -    + [str(i) for i in range(20, 261, 10)]
+ -)
+ -
+ -HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ -
+ -# --- Device and Autocast Setup ---
+ -device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ -if device.type == "cuda":
+ -    torch_dtype = torch.bfloat16
+ -else:
+ -    torch_dtype = torch.float32
+ -
+ -if device.type == "cuda":
+ -    autocast = lambda: torch.amp.autocast(device_type='cuda', dtype=torch_dtype)
+ -else:
+ -    autocast = contextlib.nullcontext
+ -
+ -# --- ImageAdapter Class ---
+ -class ImageAdapter(nn.Module):
+ -    def __init__(
+ -        self,
+ -        input_features: int,
+ -        output_features: int,
+ -        ln1: bool,
+ -        pos_emb: bool,
+ -        num_image_tokens: int,
+ -        deep_extract: bool,
+ -    ):
+ -        super().__init__()
+ -        self.deep_extract = deep_extract
+ -
+ -        if self.deep_extract:
+ -            input_features = input_features * 5
+ -
+ -        self.linear1 = nn.Linear(input_features, output_features)
+ -        self.activation = nn.GELU()
+ -        self.linear2 = nn.Linear(output_features, output_features)
+ -        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
+ -        self.pos_emb = (
+ -            None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
+ -        )
+ -
+ -        # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
+ -        self.other_tokens = nn.Embedding(3, output_features)
+ -        self.other_tokens.weight.data.normal_(
+ -            mean=0.0, std=0.02
+ -        )
+ -
+ -    def forward(self, vision_outputs: torch.Tensor):
+ -        if self.deep_extract:
+ -            x = torch.concat(
+ -                (
+ -                    vision_outputs[-2],
+ -                    vision_outputs[3],
+ -                    vision_outputs[7],
+ -                    vision_outputs[13],
+ -                    vision_outputs[20],
+ -                ),
+ -                dim=-1,
+ -            )
+ -            assert len(x.shape) == 3
+ -            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5
+ -        else:
+ -            x = vision_outputs[-2]
+ -
+ -        x = self.ln1(x)
+ -
+ -        if self.pos_emb is not None:
+ -            assert x.shape[-2:] == self.pos_emb.shape
+ -            x = x + self.pos_emb
+ -
+ -        x = self.linear1(x)
+ -        x = self.activation(x)
+ -        x = self.linear2(x)
+ -
+ -        other_tokens = self.other_tokens(
+ -            torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1)
+ -        )
+ -        assert other_tokens.shape == (x.shape[0], 2, x.shape[2])
+ -        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
+ -
+ -        return x
+ -
+ -    def get_eot_embedding(self):
+ -        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
+ -
+ -# --- load_models Function ---
+ -def load_models(CHECKPOINT_PATH, status_callback=None):
+ -    def update_status(msg):
+ -        if status_callback:
+ -            status_callback(msg)
+ -        print(msg)  # Keep console output
+ -
+ -    update_status("Loading CLIP processor...")
+ -    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
+ -    update_status("Loading CLIP vision model...")
+ -    clip_model = AutoModel.from_pretrained(CLIP_PATH)
+ -    clip_model = clip_model.vision_model
+ -
+ -    clip_model_path = CHECKPOINT_PATH / "clip_model.pt"
+ -    if not clip_model_path.exists():
+ -        raise FileNotFoundError(f"clip_model.pt not found in {CHECKPOINT_PATH}")
+ -
+ -    update_status("Loading VLM's custom vision weights...")
+ -    checkpoint = torch.load(clip_model_path, map_location="cpu")
+ -    checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
+ -    clip_model.load_state_dict(checkpoint)
+ -    del checkpoint
+ -
+ -    clip_model.eval()
+ -    clip_model.requires_grad_(False)
+ -    update_status(f"Moving CLIP to {device}...")
+ -    clip_model.to(device)
+ -
+ -    update_status("Loading tokenizer...")
+ -    tokenizer = AutoTokenizer.from_pretrained(
+ -        CHECKPOINT_PATH / "text_model", use_fast=True
+ -    )
+ -    if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+ -        raise TypeError(f"Tokenizer is of type {type(tokenizer)}")
+ -
+ -    special_tokens_dict = {'additional_special_tokens': ['<|system|>', '<|user|>', '<|end|>', '<|eot_id|>']}
+ -    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+ -    update_status(f"Added {num_added_toks} special tokens.")
+ -
+ -    update_status("Loading LLM with 4-bit quantization (this may take time)...")
+ -    text_model = AutoModelForCausalLM.from_pretrained(
+ -        CHECKPOINT_PATH / "text_model",
+ -        device_map="auto",
+ -        quantization_config=BitsAndBytesConfig(
+ -            load_in_4bit=True,
+ -            bnb_4bit_use_double_quant=True,
+ -            bnb_4bit_quant_type='nf4',
+ -            bnb_4bit_compute_dtype=torch.float16
+ -        )
+ -    )
+ -    text_model.eval()
+ -
+ -    if num_added_toks > 0:
+ -        update_status("Resizing LLM token embeddings...")
+ -        text_model.resize_token_embeddings(len(tokenizer))
+ -
+ -    update_status("Loading image adapter...")
+ -    image_adapter = ImageAdapter(
+ -        clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False
+ -    )
+ -    image_adapter_path = CHECKPOINT_PATH / "image_adapter.pt"
+ -    if not image_adapter_path.exists():
+ -        raise FileNotFoundError(f"image_adapter.pt not found in {CHECKPOINT_PATH}")
+ -
+ -    image_adapter.load_state_dict(
+ -        torch.load(image_adapter_path, map_location="cpu")
+ -    )
+ -    image_adapter.eval()
+ -    update_status(f"Moving image adapter to {device}...")
+ -    image_adapter.to(device)
+ -
+ -    update_status("Models loaded successfully.")
+ -    return clip_processor, clip_model, tokenizer, text_model, image_adapter
+ -
+ -# --- generate_caption Function ---
+ -@torch.no_grad()
+ -def generate_caption(
+ -    input_image: Image.Image,
+ -    caption_type: str,
+ -    caption_length: Union[str, int],
+ -    extra_options: List[str],
+ -    name_input: str,
+ -    custom_prompt: str,
+ -    clip_model,
+ -    tokenizer,
+ -    text_model,
+ -    image_adapter,
+ -) -> tuple:
+ -    if device.type == "cuda":
+ -        torch.cuda.empty_cache()
+ -
+ -    if custom_prompt.strip() != "":
+ -        prompt_str = custom_prompt.strip()
+ -    else:
+ -        length = None if caption_length == "any" else caption_length
+ -        if isinstance(length, str):
+ -            try:
+ -                length = int(length)
+ -            except ValueError:
+ -                pass
+ -
+ -        if length is None: map_idx = 0
+ -        elif isinstance(length, int): map_idx = 1
+ -        elif isinstance(length, str): map_idx = 2
+ -        else: raise ValueError(f"Invalid caption length: {length}")
+ -
+ -        prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
+ -        if len(extra_options) > 0: prompt_str += " " + " ".join(extra_options)
+ -        prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)
+ -
+ -    print(f"Prompt: {prompt_str}")
+ -
+ -    try:
+ -        image = input_image.convert("RGB")
+ -    except Exception as e: raise ValueError(f"Error converting image to RGB: {e}")
+ -    if image.mode != "RGB": raise ValueError(f"Image mode after conversion is {image.mode}, expected 'RGB'.")
+ -
+ -    image = image.resize((384, 384), Image.LANCZOS)
+ -    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+ -    pixel_values = TVF.normalize(pixel_values, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+ -    pixel_values = pixel_values.to(device)
+ -
+ -    with autocast():
+ -        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+ -        embedded_images = image_adapter(vision_outputs.hidden_states)
+ -        embedded_images = embedded_images.to(device)
+ -
+ -    convo = [
+ -        {"role": "system", "content": "You are a helpful image captioner."},
+ -        {"role": "user", "content": prompt_str},
+ -    ]
+ -
+ -    if hasattr(tokenizer, "apply_chat_template"):
+ -        convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
+ -    else:
+ -        convo_string = ("<|system|>\n" + convo[0]["content"] + "\n<|end|>\n<|user|>\n" + convo[1]["content"] + "\n<|end|>\n")
+ -    assert isinstance(convo_string, str)
+ -
+ -    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
+ -    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
+ -    assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
+ -    convo_tokens = convo_tokens.squeeze(0)
+ -    prompt_tokens = prompt_tokens.squeeze(0)
+ -
+ -    end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
+ -    if end_token_id is None: raise ValueError("Tokenizer missing '<|end|>' token.")
+ -    end_token_indices = (convo_tokens == end_token_id).nonzero(as_tuple=True)[0].tolist()
+ -    preamble_len = end_token_indices[0] + 1 if len(end_token_indices) >= 1 else 0
+ -
+ -    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
+ -    input_embeds = torch.cat([
+ -        convo_embeds[:, :preamble_len],
+ -        embedded_images.to(dtype=convo_embeds.dtype),
+ -        convo_embeds[:, preamble_len:],
+ -    ], dim=1).to(device)
+ -
+ -    input_ids = torch.cat([
+ -        convo_tokens[:preamble_len].unsqueeze(0),
+ -        torch.full((1, embedded_images.shape[1]), tokenizer.pad_token_id, dtype=torch.long, device=device),
+ -        convo_tokens[preamble_len:].unsqueeze(0),
+ -    ], dim=1).to(device)
+ -    attention_mask = torch.ones_like(input_ids).to(device)
+ -
+ -    print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
+ -
+ -    generate_ids = text_model.generate(
+ -        input_ids=input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask,
+ -        max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9,
+ -        suppress_tokens=None, eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")]
+ -    )
+ -
+ -    generate_ids = generate_ids[:, input_ids.shape[1]:]
+ -    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
+ -    caption = caption.strip()
+ -    caption = re.sub(r'\s+', ' ', caption)
+ -
+ -    return prompt_str, caption
+ -
+ -class CaptionApp(QMainWindow):
+ -    def __init__(self):
+ -        # ... (constructor unchanged) ...
+ -        super().__init__()
+ -        self.setWindowTitle("JoyCaption Alpha Two - Enhanced")
+ -        self.setGeometry(100, 100, 1200, 850)
+ -        self.setMinimumSize(1000, 750)
+ -
+ -        self.clip_processor = None
+ -        self.clip_model = None
+ -        self.tokenizer = None
+ -        self.text_model = None
+ -        self.image_adapter = None
+ -        self.models_loaded = False
+ -
+ -        self.input_dir = None
+ -        self.single_image_path = None
+ -        self.selected_image_path = None
+ -        self.image_files = []
+ -
+ -        self.dark_mode = False
+ -
+ -        self.central_widget = QWidget()
+ -        self.setCentralWidget(self.central_widget)
+ -        self.main_layout = QHBoxLayout(self.central_widget)
+ -
+ -        self.initUI()  # Call initUI
+ -        self.update_button_states()
+ -        self.apply_theme()
+ -
+ -
+ -    def initUI(self):
+ -        # --- Left Panel ---
+ -        left_panel = QVBoxLayout()
+ -        left_panel.setSpacing(10)
+ -
+ -        # Input directory selection
+ -        dir_layout = QHBoxLayout()
+ -        self.input_dir_button = QPushButton("Select Input Directory")
+ -        self.input_dir_button.setToolTip("Select a folder containing images to process in batch.")
+ -        self.input_dir_button.clicked.connect(self.select_input_directory)
+ -        dir_layout.addWidget(self.input_dir_button)
+ -        self.input_dir_label = QLabel("No directory selected")
+ -        self.input_dir_label.setWordWrap(True)
+ -        dir_layout.addWidget(self.input_dir_label, 1)
+ -        left_panel.addLayout(dir_layout)
+ -
+ -        # Single image selection
+ -        single_img_layout = QHBoxLayout()
+ -        self.single_image_button = QPushButton("Select Single Image")
+ -        self.single_image_button.setToolTip("Select one image file to process.")
+ -        self.single_image_button.clicked.connect(self.select_single_image)
+ -        single_img_layout.addWidget(self.single_image_button)
+ -        self.single_image_label = QLabel("No image selected")
+ -        self.single_image_label.setWordWrap(True)
+ -        single_img_layout.addWidget(self.single_image_label, 1)
+ -        left_panel.addLayout(single_img_layout)
+ -
+ -        # Caption Type
+ -        self.caption_type_combo = QComboBox()
+ -        self.caption_type_combo.addItems(CAPTION_TYPE_MAP.keys())
+ -        self.caption_type_combo.setCurrentText("Descriptive")
+ -        self.caption_type_combo.setToolTip("Choose the style or purpose of the caption.")
+ -        left_panel.addWidget(QLabel("Caption Type:"))
+ -        left_panel.addWidget(self.caption_type_combo)
+ -
+ -        # Caption Length
+ -        self.caption_length_combo = QComboBox()
+ -        self.caption_length_combo.addItems(CAPTION_LENGTH_CHOICES)
+ -        self.caption_length_combo.setCurrentText("long")
+ -        self.caption_length_combo.setToolTip("Select desired caption length or word count.")
+ -        left_panel.addWidget(QLabel("Caption Length:"))
+ -        left_panel.addWidget(self.caption_length_combo)
+ -
+ -        # Extra Options
+ -        left_panel.addWidget(QLabel("Extra Options:"))
+ -        self.extra_options_checkboxes = []
+ -        for option in EXTRA_OPTIONS_LIST:
+ -            checkbox = QCheckBox(option)
+ -            checkbox.setToolTip(option)
+ -            self.extra_options_checkboxes.append(checkbox)
+ -            left_panel.addWidget(checkbox)
+ -
+ -        # Name Input
+ -        self.name_input_line = QLineEdit()
+ -        self.name_input_line.setPlaceholderText("e.g., 'the main character'")
+ -        self.name_input_line.setToolTip("If the first extra option is checked, this name will be used.")
+ -        left_panel.addWidget(QLabel("Person/Character Name (optional):"))
+ -        left_panel.addWidget(self.name_input_line)
+ -
+ -        # Custom Prompt
+ -        self.custom_prompt_text = QTextEdit()
+ -        self.custom_prompt_text.setPlaceholderText("Overrides Caption Type/Length/Options if used.")
+ -        self.custom_prompt_text.setToolTip("Enter a full custom prompt here to ignore other settings.")
+ -        self.custom_prompt_text.setFixedHeight(80)
+ -        left_panel.addWidget(QLabel("Custom Prompt (optional):"))
+ -        left_panel.addWidget(self.custom_prompt_text)
+ -
+ -        # Checkpoint Path
+ -        ckpt_layout = QHBoxLayout()
+ -        self.checkpoint_path_line = QLineEdit()
+ -        self.checkpoint_path_line.setToolTip("Path to the folder containing model files (clip_model.pt, etc.).")
+ -        ckpt_layout.addWidget(QLabel("Checkpoint Path:"))
+ -        ckpt_layout.addWidget(self.checkpoint_path_line)
+ -        self.browse_ckpt_button = QPushButton("...")
+ -        self.browse_ckpt_button.setToolTip("Browse for Checkpoint Directory")
+ -        self.browse_ckpt_button.clicked.connect(self.browse_checkpoint_path)
+ -        self.browse_ckpt_button.setMaximumWidth(30)
+ -        ckpt_layout.addWidget(self.browse_ckpt_button)
+ -        left_panel.addLayout(ckpt_layout)
+ -
+ -        # Load Models Button
+ -        self.load_models_button = QPushButton("Load Models")
+ -        self.load_models_button.setToolTip("Load the AI models into memory (requires checkpoint path).")
+ -        self.load_models_button.clicked.connect(self.load_models_action)
+ -        left_panel.addWidget(self.load_models_button)
+ -
+ -
+ -
+ -        # Run Buttons
+ -        self.run_button = QPushButton("Generate Captions for All Images in Directory")
+ -        self.run_button.setToolTip("Process all loaded images from the selected directory.")
+ -        self.run_button.clicked.connect(self.generate_captions_action)
+ -        left_panel.addWidget(self.run_button)
+ -
+ -        self.caption_selected_button = QPushButton("Caption Selected Image from List")
+ -        self.caption_selected_button.setToolTip("Process the image currently highlighted in the list.")
+ -        self.caption_selected_button.clicked.connect(self.caption_selected_image_action)
+ -        left_panel.addWidget(self.caption_selected_button)
+ -
+ -        self.caption_single_button = QPushButton("Caption Single Loaded Image")
+ -        self.caption_single_button.setToolTip("Process the image selected via 'Select Single Image'.")
+ -        self.caption_single_button.clicked.connect(self.caption_single_image_action)
+ -        left_panel.addWidget(self.caption_single_button)
+ -
+ -        # Theme Toggle Button
+ -        self.toggle_theme_button = QPushButton("Toggle Dark Mode")
+ -        self.toggle_theme_button.setToolTip("Switch between light and dark themes.")
+ -        self.toggle_theme_button.clicked.connect(self.toggle_theme)
+ -        left_panel.addWidget(self.toggle_theme_button)
+ -
+ -        left_panel.addStretch(1)
+ -
+ -        # --- Right Panel ---
+ -        right_panel = QVBoxLayout()
+ -        right_panel.setSpacing(10)
+ -
+ -        # List widget for images
+ -        right_panel.addWidget(QLabel("Images in Directory:"))
+ -        self.image_list_widget = QListWidget()
+ -        self.image_list_widget.setIconSize(self.image_list_widget.iconSize() * 2)
+ -        self.image_list_widget.itemClicked.connect(self.display_selected_image)
+ -        self.image_list_widget.setToolTip("Click an image to view it and enable 'Caption Selected Image'.")
+ -        right_panel.addWidget(self.image_list_widget, 1)
+ -
+ -        # Label to display the selected image
+ -        right_panel.addWidget(QLabel("Selected Image Preview:"))
+ -        self.selected_image_label = QLabel("No image selected")
+ -        self.selected_image_label.setAlignment(Qt.AlignCenter)
+ -        self.selected_image_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
+ -        self.selected_image_label.setMinimumSize(300, 300)
+ -        self.selected_image_label.setStyleSheet("border: 1px solid gray;")
+ -        right_panel.addWidget(self.selected_image_label, 3)
+ -
+ -        # Generated Caption Area
+ -        right_panel.addWidget(QLabel("Generated/Editable Caption:"))
+ -        self.generated_caption_text = QTextEdit()
+ -        self.generated_caption_text.setReadOnly(False)
+ -        self.generated_caption_text.setPlaceholderText("Generated caption will appear here. You can edit it before saving.")
+ -        self.generated_caption_text.setToolTip("The generated caption appears here. Edit and use 'Save Edited Caption'.")
+ -        right_panel.addWidget(self.generated_caption_text, 1)
+ -
+ -
+ -
+ -        self.overwrite_checkbox = QCheckBox("Overwrite existing captions")
+ -        self.overwrite_checkbox.setToolTip("If checked, automatically overwrites existing .txt files without asking.")
+ -        self.append_checkbox = QCheckBox("Append to existing captions")
+ -        self.append_checkbox.setToolTip("If checked, adds the new caption to the end of the existing .txt file.")
+ -
+ -        # Layout for the save options
+ -        save_options_layout = QHBoxLayout()
+ -        save_options_layout.addWidget(self.overwrite_checkbox)
+ -        save_options_layout.addWidget(self.append_checkbox)
+ -        save_options_layout.addStretch(1)
+ -        right_panel.addLayout(save_options_layout)
+ -
+ -
+ -        self.append_checkbox.stateChanged.connect(
+ -            lambda state: self.overwrite_checkbox.setEnabled(state == Qt.Unchecked)
+ -        )
+ -
+ -
+ -        # Save Edited Caption Button
+ -        self.save_caption_button = QPushButton("Save Edited Caption to File")
+ -        self.save_caption_button.setToolTip("Save the text currently in the box above to the corresponding .txt file using the selected options.")
+ -        self.save_caption_button.clicked.connect(self.save_edited_caption_action)
+ -        right_panel.addWidget(self.save_caption_button)
+ -
+ -        # --- Main Layout Assembly
+ -        self.main_layout.addLayout(left_panel, 2)
+ -        self.main_layout.addLayout(right_panel, 5)
+ -
+ -        # --- Status Bar and Progress Bar
+ -        self.status_bar = QStatusBar()
+ -        self.setStatusBar(self.status_bar)
+ -        self.progress_bar = QProgressBar()
+ -        self.status_bar.addPermanentWidget(self.progress_bar)
+ -        self.progress_bar.hide()
+ -        self.show_status("Ready.", 5000)
+ -
+ -
596
+ -import sys
597
+ -import os
598
+ -import torch
599
+ -from torch import nn
600
+ -from transformers import (
601
+ - AutoModel,
602
+ - AutoProcessor,
603
+ - AutoTokenizer,
604
+ - PreTrainedTokenizer,
605
+ - PreTrainedTokenizerFast,
606
+ - AutoModelForCausalLM,
607
+ - BitsAndBytesConfig,
608
+ -)
609
+ -from PIL import Image
610
+ -import torchvision.transforms.functional as TVF
611
+ -import contextlib
612
+ -from typing import Union, List
613
+ -from pathlib import Path
614
+ -import re # Added for spacing fix
615
+ -
616
+ -from PyQt5.QtWidgets import (
617
+ - QApplication,
618
+ - QWidget,
619
+ - QLabel,
620
+ - QPushButton,
621
+ - QFileDialog,
622
+ - QLineEdit,
623
+ - QTextEdit,
624
+ - QComboBox,
625
+ - QVBoxLayout,
626
+ - QHBoxLayout,
627
+ - QCheckBox,
628
+ - QListWidget,
629
+ - QListWidgetItem,
630
+ - QMessageBox,
631
+ - QSizePolicy,
632
+ - QStatusBar,
633
+ - QProgressBar,
634
+ - QMainWindow,
635
+ -)
636
+ -from PyQt5.QtGui import QPixmap, QIcon
637
+ -from PyQt5.QtCore import Qt, QTimer
638
+ -
639
+ -# --- Constants and Mappings ---
640
+ -CLIP_PATH = "google/siglip-so400m-patch14-384"
641
+ -CAPTION_TYPE_MAP = {
642
+ - "Descriptive": [
643
+ - "Write a descriptive caption for this image in a formal tone.",
644
+ - "Write a descriptive caption for this image in a formal tone within {word_count} words.",
645
+ - "Write a {length} descriptive caption for this image in a formal tone.",
646
+ - ],
647
+ - "Descriptive (Informal)": [
648
+ - "Write a descriptive caption for this image in a casual tone.",
649
+ - "Write a descriptive caption for this image in a casual tone within {word_count} words.",
650
+ - "Write a {length} descriptive caption for this image in a casual tone.",
651
+ - ],
652
+ - "Training Prompt": [
653
+ - "Write a stable diffusion prompt for this image.",
654
+ - "Write a stable diffusion prompt for this image within {word_count} words.",
655
+ - "Write a {length} stable diffusion prompt for this image.",
656
+ - ],
657
+ - "MidJourney": [
658
+ - "Write a MidJourney prompt for this image.",
659
+ - "Write a MidJourney prompt for this image within {word_count} words.",
660
+ - "Write a {length} MidJourney prompt for this image.",
661
+ - ],
662
+ - "Booru tag list": [
663
+ - "Write a list of Booru tags for this image.",
664
+ - "Write a list of Booru tags for this image within {word_count} words.",
665
+ - "Write a {length} list of Booru tags for this image.",
666
+ - ],
667
+ - "Booru-like tag list": [
668
+ - "Write a list of Booru-like tags for this image.",
669
+ - "Write a list of Booru-like tags for this image within {word_count} words.",
670
+ - "Write a {length} list of Booru-like tags for this image.",
671
+ - ],
672
+ - "Art Critic": [
673
+ - "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
674
+ - "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
675
+ - "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
676
+ - ],
677
+ - "Product Listing": [
678
+ - "Write a caption for this image as though it were a product listing.",
679
+ - "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
680
+ - "Write a {length} caption for this image as though it were a product listing.",
681
+ - ],
682
+ - "Social Media Post": [
683
+ - "Write a caption for this image as if it were being used for a social media post.",
684
+ - "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
685
+ - "Write a {length} caption for this image as if it were being used for a social media post.",
686
+ - ],
687
+ -}
688
+ -
689
+ -EXTRA_OPTIONS_LIST = [
690
+ - "If there is a person/character in the image you must refer to them as {name}.",
691
+ - "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
692
+ - "Include information about lighting.",
693
+ - "Include information about camera angle.",
694
+ - "Include information about whether there is a watermark or not.",
695
+ - "Include information about whether there are JPEG artifacts or not.",
696
+ - "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
697
+ - "Do NOT include anything sexual; keep it PG.",
698
+ - "Do NOT mention the image's resolution.",
699
+ - "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
700
+ - "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
701
+ - "Do NOT mention any text that is in the image.",
702
+ - "Specify the depth of field and whether the background is in focus or blurred.",
703
+ - "If applicable, mention the likely use of artificial or natural lighting sources.",
704
+ - "Do NOT use any ambiguous language.",
705
+ - "Include whether the image is sfw, suggestive, or nsfw.",
706
+ - "ONLY describe the most important elements of the image.",
707
+ -]
708
+ -
709
+ -CAPTION_LENGTH_CHOICES = (
710
+ - ["any", "very short", "short", "medium-length", "long", "very long"]
711
+ - [str(i) for i in range(20, 261, 10)]
712
+ -)
713
+ -
714
+ -HF_TOKEN = os.environ.get("HF_TOKEN", None)
715
+ -
716
+ -# --- Device and Autocast Setup ---
717
+ -device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
718
+ -if device.type == "cuda":
719
+ - torch_dtype = torch.bfloat16
720
+ -else:
721
+ - torch_dtype = torch.float32
722
+ -
723
+ -if device.type == "cuda":
724
+ - autocast = lambda: torch.amp.autocast(device_type='cuda', dtype=torch_dtype)
725
+ -else:
726
+ - autocast = contextlib.nullcontext
727
+ -
728
+ -# --- ImageAdapter Class ---
729
+ -class ImageAdapter(nn.Module):
730
+ - def __init__(
731
+ - self,
732
+ - input_features: int,
733
+ - output_features: int,
734
+ - ln1: bool,
735
+ - pos_emb: bool,
736
+ - num_image_tokens: int,
737
+ - deep_extract: bool,
738
+ - ):
739
+ - super().__init__()
740
+ - self.deep_extract = deep_extract
741
+ -
742
+ - if self.deep_extract:
743
+ - input_features = input_features * 5
744
+ -
745
+ - self.linear1 = nn.Linear(input_features, output_features)
746
+ - self.activation = nn.GELU()
747
+ - self.linear2 = nn.Linear(output_features, output_features)
748
+ - self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
749
+ - self.pos_emb = (
750
+ - None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
751
+ - )
752
+ -
753
+ - # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
754
+ - self.other_tokens = nn.Embedding(3, output_features)
755
+ - self.other_tokens.weight.data.normal_(
756
+ - mean=0.0, std=0.02
757
+ - )
758
+ -
759
+ - def forward(self, vision_outputs: torch.Tensor):
760
+ - if self.deep_extract:
761
+ - x = torch.concat(
762
+ - (
763
+ - vision_outputs[-2],
764
+ - vision_outputs[3],
765
+ - vision_outputs[7],
766
+ - vision_outputs[13],
767
+ - vision_outputs[20],
768
+ - ),
769
+ - dim=-1,
770
+ - )
771
+ - assert len(x.shape) == 3
772
+ - assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5
773
+ - else:
774
+ - x = vision_outputs[-2]
775
+ -
776
+ - x = self.ln1(x)
777
+ -
778
+ - if self.pos_emb is not None:
779
+ - assert x.shape[-2:] == self.pos_emb.shape
780
+ - x = x self.pos_emb
781
+ -
782
+ - x = self.linear1(x)
783
+ - x = self.activation(x)
784
+ - x = self.linear2(x)
785
+ -
786
+ - other_tokens = self.other_tokens(
787
+ - torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1)
788
+ - )
789
+ - assert other_tokens.shape == (x.shape[0], 2, x.shape[2])
790
+ - x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
791
+ -
792
+ - return x
793
+ -
794
+ - def get_eot_embedding(self):
795
+ - return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
796
+ -
797
+ -# --- load_models Function ---
798
+ -def load_models(CHECKPOINT_PATH, status_callback=None):
799
+ - def update_status(msg):
800
+ - if status_callback:
801
+ - status_callback(msg)
802
+ - print(msg) # Keep console output
803
+ -
804
+ - update_status("Loading CLIP processor...")
805
+ - clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
806
+ - update_status("Loading CLIP vision model...")
807
+ - clip_model = AutoModel.from_pretrained(CLIP_PATH)
808
+ - clip_model = clip_model.vision_model
809
+ -
810
+ - clip_model_path = CHECKPOINT_PATH / "clip_model.pt"
811
+ - if not clip_model_path.exists():
812
+ - raise FileNotFoundError(f"clip_model.pt not found in {CHECKPOINT_PATH}")
813
+ -
814
+ - update_status("Loading VLM's custom vision weights...")
815
+ - checkpoint = torch.load(clip_model_path, map_location="cpu")
816
+ - checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
817
+ - clip_model.load_state_dict(checkpoint)
818
+ - del checkpoint
819
+ -
820
+ - clip_model.eval()
821
+ - clip_model.requires_grad_(False)
822
+ - update_status(f"Moving CLIP to {device}...")
823
+ - clip_model.to(device)
824
+ -
825
+ - update_status("Loading tokenizer...")
826
+ - tokenizer = AutoTokenizer.from_pretrained(
827
+ - CHECKPOINT_PATH / "text_model", use_fast=True
828
+ - )
829
+ - if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
830
+ - raise TypeError(f"Tokenizer is of type {type(tokenizer)}")
831
+ -
832
+ - special_tokens_dict = {'additional_special_tokens': ['<|system|>', '<|user|>', '<|end|>', '<|eot_id|>']}
833
+ - num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
834
+ - update_status(f"Added {num_added_toks} special tokens.")
835
+ -
836
+ - update_status("Loading LLM with 4-bit quantization (this may take time)...")
837
+ - text_model = AutoModelForCausalLM.from_pretrained(
838
+ - CHECKPOINT_PATH / "text_model",
839
+ - device_map="auto",
840
+ - quantization_config=BitsAndBytesConfig(
841
+ - load_in_4bit=True,
842
+ - bnb_4bit_use_double_quant=True,
843
+ - bnb_4bit_quant_type='nf4',
844
+ - bnb_4bit_compute_dtype=torch.float16
845
+ - )
846
+ - )
847
+ - text_model.eval()
848
+ -
849
+ - if num_added_toks > 0:
850
+ - update_status("Resizing LLM token embeddings...")
851
+ - text_model.resize_token_embeddings(len(tokenizer))
852
+ -
853
+ - update_status("Loading image adapter...")
854
+ - image_adapter = ImageAdapter(
855
+ - clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False
856
+ - )
857
+ - image_adapter_path = CHECKPOINT_PATH / "image_adapter.pt"
858
+ - if not image_adapter_path.exists():
859
+ - raise FileNotFoundError(f"image_adapter.pt not found in {CHECKPOINT_PATH}")
860
+ -
861
+ - image_adapter.load_state_dict(
862
+ - torch.load(image_adapter_path, map_location="cpu")
863
+ - )
864
+ - image_adapter.eval()
865
+ - update_status(f"Moving image adapter to {device}...")
866
+ - image_adapter.to(device)
867
+ -
868
+ - update_status("Models loaded successfully.")
869
+ - return clip_processor, clip_model, tokenizer, text_model, image_adapter
870
+ -
871
+ -# --- generate_caption Function ---
872
+ -@torch.no_grad()
873
+ -def generate_caption(
874
+ - input_image: Image.Image,
875
+ - caption_type: str,
876
+ - caption_length: Union[str, int],
877
+ - extra_options: List[str],
878
+ - name_input: str,
879
+ - custom_prompt: str,
880
+ - clip_model,
881
+ - tokenizer,
882
+ - text_model,
883
+ - image_adapter,
884
+ -) -> tuple:
885
+ - if device.type == "cuda":
886
+ - torch.cuda.empty_cache()
887
+ -
888
+ - if custom_prompt.strip() != "":
889
+ - prompt_str = custom_prompt.strip()
890
+ - else:
891
+ - length = None if caption_length == "any" else caption_length
892
+ - if isinstance(length, str):
893
+ - try:
894
+ - length = int(length)
895
+ - except ValueError:
896
+ - pass
897
+ -
898
+ - if length is None: map_idx = 0
899
+ - elif isinstance(length, int): map_idx = 1
900
+ - elif isinstance(length, str): map_idx = 2
901
+ - else: raise ValueError(f"Invalid caption length: {length}")
902
+ -
903
+ - prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
904
+ - if len(extra_options) > 0: prompt_str = " " " ".join(extra_options)
905
+ - prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)
906
+ -
907
+ - print(f"Prompt: {prompt_str}")
908
+ -
909
+ - try:
910
+ - image = input_image.convert("RGB")
911
+ - except Exception as e: raise ValueError(f"Error converting image to RGB: {e}")
912
+ - if image.mode != "RGB": raise ValueError(f"Image mode after conversion is {image.mode}, expected 'RGB'.")
913
+ -
914
+ - image = image.resize((384, 384), Image.LANCZOS)
915
+ - pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
916
+ - pixel_values = TVF.normalize(pixel_values, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
917
+ - pixel_values = pixel_values.to(device)
918
+ -
919
+ - with autocast():
920
+ - vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
921
+ - embedded_images = image_adapter(vision_outputs.hidden_states)
922
+ - embedded_images = embedded_images.to(device)
923
+ -
924
+ - convo = [
925
+ - {"role": "system", "content": "You are a helpful image captioner."},
926
+ - {"role": "user", "content": prompt_str},
927
+ - ]
928
+ -
929
+ - if hasattr(tokenizer, "apply_chat_template"):
930
+ - convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
931
+ - else:
932
+ - convo_string = ("<|system|>\n" convo[0]["content"] "\n<|end|>\n<|user|>\n" convo[1]["content"] "\n<|end|>\n")
933
+ - assert isinstance(convo_string, str)
934
+ -
935
+ - convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
936
+ - prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
937
+ - assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
938
+ - convo_tokens = convo_tokens.squeeze(0)
939
+ - prompt_tokens = prompt_tokens.squeeze(0)
940
+ -
941
+ - end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
942
+ - if end_token_id is None: raise ValueError("Tokenizer missing '<|end|>' token.")
943
+ - end_token_indices = (convo_tokens == end_token_id).nonzero(as_tuple=True)[0].tolist()
944
+ - preamble_len = end_token_indices[0] 1 if len(end_token_indices) >= 1 else 0
945
+ -
946
+ - convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
947
+ - input_embeds = torch.cat([
948
+ - convo_embeds[:, :preamble_len],
949
+ - embedded_images.to(dtype=convo_embeds.dtype),
950
+ - convo_embeds[:, preamble_len:],
951
+ - ], dim=1).to(device)
952
+ -
953
+ - input_ids = torch.cat([
954
+ - convo_tokens[:preamble_len].unsqueeze(0),
955
+ - torch.full((1, embedded_images.shape[1]), tokenizer.pad_token_id, dtype=torch.long, device=device),
956
+ - convo_tokens[preamble_len:].unsqueeze(0),
957
+ - ], dim=1).to(device)
958
+ - attention_mask = torch.ones_like(input_ids).to(device)
959
+ -
960
+ - print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
961
+ -
962
+ - generate_ids = text_model.generate(
963
+ - input_ids=input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask,
964
+ - max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9,
965
+ - suppress_tokens=None, eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")]
966
+ - )
967
+ -
968
+ - generate_ids = generate_ids[:, input_ids.shape[1]:]
969
+ - caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
970
+ - caption = caption.strip()
971
+ - caption = re.sub(r'\s', ' ', caption)
972
+ -
973
+ - return prompt_str, caption
974
+ -
975
+ -# --- CaptionApp Class ---
976
+ -class CaptionApp(QMainWindow):
977
+ - def __init__(self):
978
+ - super().__init__()
979
+ - self.setWindowTitle("JoyCaption Alpha Two - Enhanced")
980
+ - self.setGeometry(100, 100, 1200, 850)
981
+ - self.setMinimumSize(1000, 750)
982
+ -
983
+ - self.clip_processor = None
984
+ - self.clip_model = None
985
+ - self.tokenizer = None
986
+ - self.text_model = None
987
+ - self.image_adapter = None
988
+ - self.models_loaded = False
989
+ -
990
+ - self.input_dir = None
991
+ - self.single_image_path = None
992
+ - self.selected_image_path = None
993
+ - self.image_files = []
994
+ -
995
+ - self.dark_mode = False
996
+ -
997
+ - self.central_widget = QWidget()
998
+ - self.setCentralWidget(self.central_widget)
999
+ - self.main_layout = QHBoxLayout(self.central_widget)
1000
+ -
1001
+ - self.initUI() # Call initUI
1002
+ - self.update_button_states()
1003
+ - self.apply_theme()
1004
+ -
1005
+     def initUI(self):
+         # --- Left Panel ---
+         left_panel = QVBoxLayout()
+         left_panel.setSpacing(10)
+
+         # Input directory selection
+         dir_layout = QHBoxLayout()
+         self.input_dir_button = QPushButton("Select Input Directory")
+         self.input_dir_button.setToolTip("Select a folder containing images to process in batch.")
+         self.input_dir_button.clicked.connect(self.select_input_directory)
+         dir_layout.addWidget(self.input_dir_button)
+         self.input_dir_label = QLabel("No directory selected")
+         self.input_dir_label.setWordWrap(True)
+         dir_layout.addWidget(self.input_dir_label, 1)
+         left_panel.addLayout(dir_layout)
+
+         # Single image selection
+         single_img_layout = QHBoxLayout()
+         self.single_image_button = QPushButton("Select Single Image")
+         self.single_image_button.setToolTip("Select one image file to process.")
+         self.single_image_button.clicked.connect(self.select_single_image)
+         single_img_layout.addWidget(self.single_image_button)
+         self.single_image_label = QLabel("No image selected")
+         self.single_image_label.setWordWrap(True)
+         single_img_layout.addWidget(self.single_image_label, 1)
+         left_panel.addLayout(single_img_layout)
+
+         # Caption Type
+         self.caption_type_combo = QComboBox()
+         self.caption_type_combo.addItems(CAPTION_TYPE_MAP.keys())
+         self.caption_type_combo.setCurrentText("Descriptive")
+         self.caption_type_combo.setToolTip("Choose the style or purpose of the caption.")
+         left_panel.addWidget(QLabel("Caption Type:"))
+         left_panel.addWidget(self.caption_type_combo)
+
+         # Caption Length
+         self.caption_length_combo = QComboBox()
+         self.caption_length_combo.addItems(CAPTION_LENGTH_CHOICES)
+         self.caption_length_combo.setCurrentText("long")
+         self.caption_length_combo.setToolTip("Select desired caption length or word count.")
+         left_panel.addWidget(QLabel("Caption Length:"))
+         left_panel.addWidget(self.caption_length_combo)
+
+         # Extra Options
+         left_panel.addWidget(QLabel("Extra Options:"))
+         self.extra_options_checkboxes = []
+         for option in EXTRA_OPTIONS_LIST:
+             checkbox = QCheckBox(option)
+             checkbox.setToolTip(option)
+             self.extra_options_checkboxes.append(checkbox)
+             left_panel.addWidget(checkbox)
+
+         # Name Input
+         self.name_input_line = QLineEdit()
+         self.name_input_line.setPlaceholderText("e.g., 'the main character'")
+         self.name_input_line.setToolTip("If the first extra option is checked, this name will be used.")
+         left_panel.addWidget(QLabel("Person/Character Name (optional):"))
+         left_panel.addWidget(self.name_input_line)
+
+         # Custom Prompt
+         self.custom_prompt_text = QTextEdit()
+         self.custom_prompt_text.setPlaceholderText("Overrides Caption Type/Length/Options if used.")
+         self.custom_prompt_text.setToolTip("Enter a full custom prompt here to ignore other settings.")
+         self.custom_prompt_text.setFixedHeight(80)
+         left_panel.addWidget(QLabel("Custom Prompt (optional):"))
+         left_panel.addWidget(self.custom_prompt_text)
+
+         # Checkpoint Path
+         ckpt_layout = QHBoxLayout()
+         self.checkpoint_path_line = QLineEdit()
+         self.checkpoint_path_line.setToolTip("Path to the folder containing model files (clip_model.pt, etc.).")
+         ckpt_layout.addWidget(QLabel("Checkpoint Path:"))
+         ckpt_layout.addWidget(self.checkpoint_path_line)
+         self.browse_ckpt_button = QPushButton("...")
+         self.browse_ckpt_button.setToolTip("Browse for Checkpoint Directory")
+         self.browse_ckpt_button.clicked.connect(self.browse_checkpoint_path)
+         self.browse_ckpt_button.setMaximumWidth(30)
+         ckpt_layout.addWidget(self.browse_ckpt_button)
+         left_panel.addLayout(ckpt_layout)
+
+         # Load Models Button
+         self.load_models_button = QPushButton("Load Models")
+         self.load_models_button.setToolTip("Load the AI models into memory (requires checkpoint path).")
+         self.load_models_button.clicked.connect(self.load_models_action)
+         left_panel.addWidget(self.load_models_button)
+
+         # Run Buttons
+         self.run_button = QPushButton("Generate Captions for All Images in Directory")
+         self.run_button.setToolTip("Process all loaded images from the selected directory.")
+         self.run_button.clicked.connect(self.generate_captions_action)
+         left_panel.addWidget(self.run_button)
+
+         self.caption_selected_button = QPushButton("Caption Selected Image from List")
+         self.caption_selected_button.setToolTip("Process the image currently highlighted in the list.")
+         self.caption_selected_button.clicked.connect(self.caption_selected_image_action)
+         left_panel.addWidget(self.caption_selected_button)
+
+         self.caption_single_button = QPushButton("Caption Single Loaded Image")
+         self.caption_single_button.setToolTip("Process the image selected via 'Select Single Image'.")
+         self.caption_single_button.clicked.connect(self.caption_single_image_action)
+         left_panel.addWidget(self.caption_single_button)
+
+         # Theme Toggle Button
+         self.toggle_theme_button = QPushButton("Toggle Dark Mode")
+         self.toggle_theme_button.setToolTip("Switch between light and dark themes.")
+         self.toggle_theme_button.clicked.connect(self.toggle_theme)
+         left_panel.addWidget(self.toggle_theme_button)
+
+         left_panel.addStretch(1)
+
+         # --- Right Panel ---
+         right_panel = QVBoxLayout()
+         right_panel.setSpacing(10)
+
+         # List widget for images
+         right_panel.addWidget(QLabel("Images in Directory:"))
+         self.image_list_widget = QListWidget()
+         self.image_list_widget.setIconSize(self.image_list_widget.iconSize() * 2)
+         self.image_list_widget.itemClicked.connect(self.display_selected_image)
+         self.image_list_widget.setToolTip("Click an image to view it and enable 'Caption Selected Image'.")
+         right_panel.addWidget(self.image_list_widget, 1)
+
+         # Label to display the selected image
+         right_panel.addWidget(QLabel("Selected Image Preview:"))
+         self.selected_image_label = QLabel("No image selected")
+         self.selected_image_label.setAlignment(Qt.AlignCenter)
+         self.selected_image_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
+         self.selected_image_label.setMinimumSize(300, 300)
+         self.selected_image_label.setStyleSheet("border: 1px solid gray;")
+         right_panel.addWidget(self.selected_image_label, 3)
+
+         # Generated Caption Area
+         right_panel.addWidget(QLabel("Generated/Editable Caption:"))
+         self.generated_caption_text = QTextEdit()
+         self.generated_caption_text.setReadOnly(False)
+         self.generated_caption_text.setPlaceholderText("Generated caption will appear here. You can edit it before saving.")
+         self.generated_caption_text.setToolTip("The generated caption appears here. Edit and use 'Save Edited Caption'.")
+         right_panel.addWidget(self.generated_caption_text, 1)
+
+         # Saving Options
+         self.overwrite_checkbox = QCheckBox("Overwrite existing captions")
+         self.overwrite_checkbox.setToolTip("If checked, automatically overwrites existing .txt files without asking.")
+         self.append_checkbox = QCheckBox("Append to existing captions")
+         self.append_checkbox.setToolTip("If checked, adds the new caption to the end of the existing .txt file.")
+
+         save_options_layout = QHBoxLayout()
+         save_options_layout.addWidget(self.overwrite_checkbox)
+         save_options_layout.addWidget(self.append_checkbox)
+         save_options_layout.addStretch(1)
+         right_panel.addLayout(save_options_layout)  # Add layout here
+
+         self.append_checkbox.stateChanged.connect(
+             lambda state: self.overwrite_checkbox.setEnabled(state == Qt.Unchecked)
+         )
+
+         # Save Edited Caption Button
+         self.save_caption_button = QPushButton("Save Edited Caption to File")
+         self.save_caption_button.setToolTip("Save the text currently in the box above to the corresponding .txt file using the selected options.")
+         self.save_caption_button.clicked.connect(self.save_edited_caption_action)
+         right_panel.addWidget(self.save_caption_button)
+
+         # --- Main Layout Assembly ---
+         self.main_layout.addLayout(left_panel, 2)
+         self.main_layout.addLayout(right_panel, 5)
+
+         # --- Status Bar and Progress Bar ---
+         self.status_bar = QStatusBar()
+         self.setStatusBar(self.status_bar)
+         self.progress_bar = QProgressBar()
+         self.status_bar.addPermanentWidget(self.progress_bar)
+         self.progress_bar.hide()
+         self.show_status("Ready.", 5000)
+
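+     # Small helpers. show_status() calls processEvents() so status messages
+     # repaint even while the GUI thread is busy with synchronous model work.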
+     def browse_checkpoint_path(self):
+         directory = QFileDialog.getExistingDirectory(self, "Select Checkpoint Directory")
+         if directory:
+             self.checkpoint_path_line.setText(directory)
+             self.update_button_states()
+
+     def show_status(self, message, timeout=0):
+         self.status_bar.showMessage(message, timeout)
+         QApplication.processEvents()
+
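+     # Central enable/disable logic: every action button is gated on its
+     # preconditions (models loaded, directory or image selected, caption text).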
+     def update_button_states(self):
+         self.load_models_button.setEnabled(bool(self.checkpoint_path_line.text()))
+         models_ready = self.models_loaded
+         dir_selected = self.input_dir is not None and bool(self.image_files)
+         single_img_selected = self.single_image_path is not None
+         list_img_selected = self.selected_image_path is not None
+         caption_present = bool(self.generated_caption_text.toPlainText().strip())
+
+         self.run_button.setEnabled(models_ready and dir_selected)
+         self.caption_selected_button.setEnabled(models_ready and list_img_selected)
+         self.caption_single_button.setEnabled(models_ready and single_img_selected)
+         self.save_caption_button.setEnabled(caption_present and (list_img_selected or single_img_selected))
+
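+     # Theme handling via a Qt stylesheet. Note: "placeholderTextColor" is not
+     # a standard QSS property, so depending on the Qt version it may simply
+     # be ignored and placeholder text keeps the palette color.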
+     def apply_theme(self):
+         dark_stylesheet = """
+             QMainWindow, QWidget { background-color: #2E2E2E; color: #FFFFFF; font-family: Arial, sans-serif; }
+             QPushButton { background-color: #3A3A3A; color: #FFFFFF; border: 1px solid #555555; padding: 5px; min-height: 20px; }
+             QPushButton:hover { background-color: #555555; }
+             QPushButton:disabled { background-color: #454545; color: #888888; }
+             QLabel { color: #FFFFFF; }
+             QLineEdit, QTextEdit, QComboBox { background-color: #3A3A3A; color: #FFFFFF; border: 1px solid #555555; padding: 4px; }
+             QLineEdit:disabled, QTextEdit:disabled, QComboBox:disabled { background-color: #454545; color: #888888; }
+             QListWidget { background-color: #3A3A3A; color: #FFFFFF; border: 1px solid #555555; alternate-background-color: #424242; }
+             QCheckBox { color: #FFFFFF; spacing: 5px; }
+             QCheckBox::indicator { width: 13px; height: 13px; }
+             QStatusBar { color: #FFFFFF; } QStatusBar::item { border: none; }
+             QProgressBar { border: 1px solid #555555; text-align: center; color: #FFFFFF; background-color: #3A3A3A; }
+             QProgressBar::chunk { background-color: #007ADF; width: 10px; margin: 0.5px; }
+             QToolTip { background-color: #464646; color: #FFFFFF; border: 1px solid #555555; padding: 4px; }
+             QTextEdit { placeholderTextColor: gray; } QLineEdit { placeholderTextColor: gray; }
+         """
+         if self.dark_mode: self.setStyleSheet(dark_stylesheet)
+         else: self.setStyleSheet("")
+
+         placeholder_style = "QTextEdit { placeholderTextColor: gray; } QLineEdit { placeholderTextColor: gray; }"
+         current_style = self.styleSheet()
+         if self.dark_mode:
+             if "placeholderTextColor" not in current_style: self.setStyleSheet(current_style + placeholder_style)
+         else: self.setStyleSheet(current_style.replace(placeholder_style, ""))
+
+     def toggle_theme(self):
+         self.dark_mode = not self.dark_mode
+         self.apply_theme()
+
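+     # Input selection. Directory mode and single-image mode are mutually
+     # exclusive: choosing one clears the state of the other.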
+     def select_input_directory(self):
+         directory = QFileDialog.getExistingDirectory(self, "Select Input Directory")
+         if directory:
+             self.input_dir = Path(directory)
+             self.input_dir_label.setText(str(self.input_dir))
+             self.single_image_path = None; self.single_image_label.setText("No image selected")
+             self.selected_image_path = None; self.selected_image_label.setText("No image selected")
+             self.generated_caption_text.clear()
+             self.load_images()
+             self.show_status(f"Selected directory: {self.input_dir.name}", 5000)
+         else:
+             self.input_dir_label.setText("No directory selected"); self.input_dir = None
+             self.image_list_widget.clear(); self.image_files = []
+             self.show_status("Directory selection cancelled.", 3000)
+         self.update_button_states()
+
+     def select_single_image(self):
+         file_filter = "Image Files (*.jpg *.jpeg *.png *.bmp *.gif *.tiff *.webp)"
+         file_path, _ = QFileDialog.getOpenFileName(self, "Select Single Image", "", file_filter)
+         if file_path:
+             self.single_image_path = Path(file_path)
+             self.single_image_label.setText(str(self.single_image_path.name))
+             self.input_dir = None; self.input_dir_label.setText("No directory selected")
+             self.image_list_widget.clear(); self.image_files = []
+             self.selected_image_path = None
+             self.display_image(self.single_image_path)
+             self.show_status(f"Selected single image: {self.single_image_path.name}", 5000)
+         else:
+             self.single_image_label.setText("No image selected"); self.single_image_path = None
+             self.show_status("Single image selection cancelled.", 3000)
+         self.update_button_states()
+
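+     # Scan the chosen directory for supported extensions and fill the list
+     # widget with filename plus thumbnail. Thumbnails are built synchronously,
+     # so very large folders can take a moment to load.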
+     def load_images(self):
+         if not self.input_dir: return
+         self.show_status(f"Loading images from {self.input_dir.name}...")
+         image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"]
+         try:
+             self.image_files = sorted([f for f in self.input_dir.iterdir() if f.is_file() and f.suffix.lower() in image_extensions])
+         except Exception as e:
+             QMessageBox.critical(self, "Directory Error", f"Could not read directory contents:\n{e}")
+             self.show_status(f"Error reading directory {self.input_dir.name}", 5000)
+             self.image_files = []; self.input_dir = None; self.input_dir_label.setText("Error reading directory")
+
+         self.image_list_widget.clear()
+         if not self.image_files:
+             if self.input_dir:
+                 QMessageBox.warning(self, "No Images", "No supported image files found.")
+                 self.show_status("No images found in directory.", 3000)
+             self.update_button_states()
+             return
+
+         thumb_size = 100
+         for image_path in self.image_files:
+             item = QListWidgetItem(str(image_path.name))
+             try:
+                 pixmap = QPixmap(str(image_path))
+                 if not pixmap.isNull():
+                     scaled_pixmap = pixmap.scaled(thumb_size, thumb_size, Qt.KeepAspectRatio, Qt.SmoothTransformation)
+                     item.setIcon(QIcon(scaled_pixmap))
+                 else: print(f"Warning: QPixmap is null for {image_path.name}")
+             except Exception as e: print(f"Warning: Could not create thumbnail for {image_path.name}: {e}")
+             self.image_list_widget.addItem(item)
+
+         self.show_status(f"Loaded {len(self.image_files)} images.", 5000)
+         self.update_button_states()
+
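+     # A click in the image list shows the preview and, if a matching .txt
+     # sidecar exists, loads it into the caption editor.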
+     def display_selected_image(self, item):
+         if not self.input_dir or not item: return
+         try:
+             image_name = item.text()
+             image_path = self.input_dir / image_name
+             if not image_path.exists():
+                 QMessageBox.warning(self, "File Not Found", f"Image file '{image_name}' no longer exists.")
+                 self.selected_image_label.setText("File not found")
+                 self.selected_image_label.setPixmap(QPixmap())
+                 self.generated_caption_text.clear()
+                 self.selected_image_path = None
+                 return
+
+             self.selected_image_path = image_path
+             self.single_image_path = None
+             self.single_image_label.setText("No image selected")
+             self.display_image(image_path)
+             caption_file_path = image_path.with_suffix('.txt')
+             if caption_file_path.exists():
+                 try:
+                     with open(caption_file_path, 'r', encoding='utf-8') as f:
+                         caption_content = f.read()
+                     self.generated_caption_text.setText(caption_content)
+                     status_message = f"Displayed {image_name} and loaded existing caption."
+                 except Exception as e:
+                     print(f"Warning: Could not read caption file {caption_file_path.name}: {e}")
+                     # Keep caption box clear or show error placeholder
+                     self.generated_caption_text.setPlaceholderText(f"Error reading caption file for {image_name}.")
+                     status_message = f"Displayed {image_name}, but failed to load caption file."
+             else:
+                 # Keep caption box clear (already done by display_image)
+                 self.generated_caption_text.setPlaceholderText("Generate or edit caption here.")
+                 status_message = f"Displayed {image_name}. No existing caption found."
+             self.show_status(status_message, 4000)
+         except Exception as e:
+             self.selected_image_label.setText("Error loading preview")
+             self.selected_image_path = None
+             QMessageBox.warning(self, "Preview Error", f"Could not load preview for {item.text()}: {e}")
+             self.show_status(f"Error loading preview for {item.text()}", 4000)
+         self.update_button_states()
+
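+     # Preview rendering. The pixmap is scaled to the label's current size
+     # times devicePixelRatioF(), presumably to keep the preview sharp on
+     # high-DPI displays.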
+     def display_image(self, image_path):
+         try:
+             pixmap = QPixmap(str(image_path))
+             if not pixmap.isNull():
+                 self.scale_and_set_pixmap(pixmap)
+                 self.generated_caption_text.clear()
+             else:
+                 self.selected_image_label.setText(f"Cannot display image:\n{image_path.name}")
+                 self.selected_image_label.setPixmap(QPixmap())
+         except Exception as e:
+             self.selected_image_label.setText(f"Error loading preview:\n{image_path.name}")
+             self.selected_image_label.setPixmap(QPixmap())
+             print(f"Error displaying image {image_path}: {e}")
+             self.show_status(f"Error displaying image {image_path.name}", 4000)
+         self.update_button_states()
+
+     def scale_and_set_pixmap(self, pixmap):
+         if not pixmap or pixmap.isNull():
+             self.selected_image_label.clear()
+             self.selected_image_label.setText("No image selected")
+             return
+         label_size = self.selected_image_label.contentsRect().size()
+         scaled_pixmap = pixmap.scaled(label_size * self.devicePixelRatioF(), Qt.KeepAspectRatio, Qt.SmoothTransformation)
+         self.selected_image_label.setPixmap(scaled_pixmap)
+
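+     # Load all model components from the checkpoint directory. This runs on
+     # the GUI thread, so the progress bar is switched to "busy" mode via
+     # setRange(0, 0) and the window stays unresponsive until loading ends.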
+     def load_models_action(self):
+         checkpoint_path_str = self.checkpoint_path_line.text()
+         if not checkpoint_path_str: QMessageBox.warning(self, "Checkpoint Error", "Please specify the checkpoint path."); return
+         checkpoint_path = Path(checkpoint_path_str)
+         if not checkpoint_path.exists() or not checkpoint_path.is_dir():
+             QMessageBox.warning(self, "Checkpoint Error", f"Checkpoint path does not exist or is not a directory:\n{checkpoint_path}"); return
+
+         self.show_status("Loading models... This might take a while.", 0)
+         self.progress_bar.setRange(0, 0); self.progress_bar.show(); QApplication.processEvents()
+         try:
+             (self.clip_processor, self.clip_model, self.tokenizer, self.text_model, self.image_adapter) = load_models(checkpoint_path, status_callback=self.show_status)
+             self.models_loaded = True
+             QMessageBox.information(self, "Models Loaded", "Models have been loaded successfully.")
+             self.show_status("Models loaded successfully. Ready to caption.", 5000)
+         except Exception as e:
+             self.models_loaded = False
+             QMessageBox.critical(self, "Model Loading Error", f"An error occurred while loading models:\n{e}\n\nCheck console for details.")
+             self.show_status(f"Model loading failed. Check console.", 0)
+             print(f"--- Model Loading Error ---"); import traceback; traceback.print_exc(); print(f"--- End Error Traceback ---")
+         finally:
+             self.progress_bar.hide(); self.progress_bar.setRange(0, 100); self.update_button_states()
+
+     def collect_parameters(self):
+         return (self.caption_type_combo.currentText(), self.caption_length_combo.currentText(),
+                 [cb.text() for cb in self.extra_options_checkboxes if cb.isChecked()],
+                 self.name_input_line.text(), self.custom_prompt_text.toPlainText())
+
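+     # Saving helpers. Append mode opens the file with mode 'a' and prefixes a
+     # newline when the file already has content; write mode asks before
+     # overwriting unless the overwrite checkbox is ticked.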
+     def _confirm_overwrite(self, file_path: Path) -> bool:
+         if file_path.exists():
+             reply = QMessageBox.question(self, 'Confirm Overwrite', f"Caption file '{file_path.name}' already exists.\nOverwrite?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
+             return reply == QMessageBox.Yes
+         return True
+
+     def _save_caption_to_file(self, image_path: Path, caption: str) -> bool:
+         if not image_path: self.show_status("Error: No image path associated.", 5000); return False
+         caption_file = image_path.with_suffix('.txt')
+         mode = 'a' if self.append_checkbox.isChecked() else 'w'
+         prefix = '\n' if mode == 'a' and caption_file.exists() and caption_file.stat().st_size > 0 else ''
+
+         if mode == 'w' and caption_file.exists() and not self.overwrite_checkbox.isChecked():
+             if not self._confirm_overwrite(caption_file):
+                 self.show_status(f"Skipped saving {image_path.name}.", 3000); return False
+         try:
+             with open(caption_file, mode, encoding='utf-8') as f: f.write(f"{prefix}{caption}")
+             self.show_status(f"Caption {'appended to' if mode == 'a' else 'saved to'} {caption_file.name}", 4000); return True
+         except Exception as e:
+             QMessageBox.critical(self, "Save Error", f"Error saving caption for {image_path.name}:\n{e}")
+             self.show_status(f"Error saving caption for {image_path.name}", 5000); print(f"Error saving caption to {caption_file}: {e}"); return False
+
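+     # Core single-image pipeline: open the image, build the prompt from the
+     # UI parameters, run generate_caption, show the result if this image is
+     # currently previewed, then save it to the .txt sidecar.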
+     def _run_caption_generation(self, image_path: Path):
+         if not self.models_loaded: QMessageBox.warning(self, "Models Not Loaded", "Please load models first."); return None
+         if not image_path or not image_path.exists():
+             QMessageBox.warning(self, "Image Not Found", f"Image file does not exist:\n{image_path}")
+             self.show_status(f"Image not found: {image_path.name if image_path else 'None'}", 5000); return None
+
+         self.show_status(f"Processing: {image_path.name}...", 0); QApplication.processEvents()
+         params = self.collect_parameters()
+         try: input_image = Image.open(image_path)
+         except Exception as e:
+             QMessageBox.critical(self, "Image Open Error", f"Failed to open {image_path.name}:\n{e}")
+             self.show_status(f"Error opening {image_path.name}", 5000); print(f"Error opening image {image_path}: {e}"); return None
+         try:
+             prompt_str, caption = generate_caption(input_image, *params, self.clip_model, self.tokenizer, self.text_model, self.image_adapter)
+             current_viewed_path = self.selected_image_path or self.single_image_path
+             if image_path == current_viewed_path: self.generated_caption_text.setText(caption)
+             if self._save_caption_to_file(image_path, caption): print(f"Caption generated and saved for {image_path.name}")
+             else: print(f"Caption generated but NOT saved for {image_path.name}")
+             return caption
+         except Exception as e:
+             QMessageBox.critical(self, "Processing Error", f"Failed to process {image_path.name}:\n{e}\n\nCheck console.")
+             self.show_status(f"Error processing {image_path.name}. Check console.", 0)
+             print(f"--- Processing Error for {image_path.name} ---"); import traceback; traceback.print_exc(); print(f"--- End Error Traceback ---")
+             current_viewed_path = self.selected_image_path or self.single_image_path
+             if image_path == current_viewed_path: self.generated_caption_text.setText(f"Error generating caption. See console.")
+             return None
+         finally: QApplication.processEvents()
+
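+     # Batch mode: iterate over every image in the directory. The processed
+     # and error counters are heuristic, because _run_caption_generation only
+     # returns the caption or None; see the inline comments below.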
+     def generate_captions_action(self):
+         if not self.input_dir or not self.image_files: QMessageBox.warning(self, "No Images", "Select directory with images first."); return
+         if not self.models_loaded: QMessageBox.warning(self, "Models Not Loaded", "Load models first."); return
+
+         num_images = len(self.image_files)
+         self.progress_bar.setRange(0, num_images); self.progress_bar.setValue(0); self.progress_bar.show()
+         self.show_status(f"Starting batch captioning for {num_images} images...", 0)
+
+         processed_count, error_count, skipped_explicitly = 0, 0, 0
+         original_overwrite_state = self.overwrite_checkbox.isChecked()  # Remember original state
+         ask_all = False  # Flag to check if user agreed to overwrite all
+
+         # Pre-check for overwrites if needed
+         files_to_confirm = []
+         if not self.overwrite_checkbox.isChecked() and not self.append_checkbox.isChecked():
+             files_to_confirm = [img.with_suffix('.txt').name for img in self.image_files if img.with_suffix('.txt').exists()]
+
+         if files_to_confirm:
+             reply = QMessageBox.question(self, 'Confirm Overwrite Multiple', f"{len(files_to_confirm)} existing caption file(s) found.\nOverwrite ALL existing files?", QMessageBox.Yes | QMessageBox.No | QMessageBox.Cancel, QMessageBox.Cancel)
+             if reply == QMessageBox.Cancel: self.show_status("Batch cancelled.", 3000); self.progress_bar.hide(); return
+             elif reply == QMessageBox.Yes: ask_all = True; self.overwrite_checkbox.setChecked(True)  # Temporarily check it
+
+         # Process images
+         for i, image_path in enumerate(self.image_files):
+             # Run generation. _save_caption_to_file handles individual confirmation if ask_all is False
+             caption_result = self._run_caption_generation(image_path)
+
+             # Track results (Approximate - relies on _save reporting skips)
+             if caption_result is not None:
+                 processed_count += 1
+             else:
+                 # If None, assume error unless status bar indicates skip (imperfect)
+                 if "Skipped saving" not in self.status_bar.currentMessage():
+                     error_count += 1
+                 # No reliable way to count skips here without modifying _save return value
+
+             self.progress_bar.setValue(i + 1)
+             QApplication.processEvents()
+
+         # Restore overwrite checkbox state if changed
+         if ask_all: self.overwrite_checkbox.setChecked(original_overwrite_state)
+
+         self.progress_bar.hide()
+         final_message = f"Batch finished. {processed_count} captions generated/saved."
+         if error_count > 0: final_message += f" {error_count} errors."
+         # Cannot reliably report skips here
+         QMessageBox.information(self, "Batch Complete", final_message)
+         self.show_status(final_message, 10000)
+         self.update_button_states()
+
+     def caption_selected_image_action(self):
+         if not self.selected_image_path: QMessageBox.warning(self, "No Image Selected", "Select image from list first."); return
+         self._run_caption_generation(self.selected_image_path); self.update_button_states()
+
+     def caption_single_image_action(self):
+         if not self.single_image_path: QMessageBox.warning(self, "No Image Selected", "Select single image first."); return
+         self._run_caption_generation(self.single_image_path); self.update_button_states()
+
+     def save_edited_caption_action(self):
+         edited_caption = self.generated_caption_text.toPlainText().strip()
+         if not edited_caption: QMessageBox.warning(self, "Empty Caption", "Caption text is empty."); return
+         current_image_path = self.selected_image_path or self.single_image_path
+         if not current_image_path: QMessageBox.warning(self, "No Associated Image", "Select image first."); return
+         self._save_caption_to_file(current_image_path, edited_caption)
+
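+     # Rescale the preview whenever the window is resized. The pixmap is
+     # reloaded from disk so repeated rescaling does not accumulate blur.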
+     def resizeEvent(self, event):
+         super().resizeEvent(event)
+         current_path = None
+         if self.selected_image_label.pixmap() and not self.selected_image_label.pixmap().isNull():
+             current_path = self.selected_image_path or self.single_image_path
+         if current_path and current_path.exists():
+             try:
+                 pixmap = QPixmap(str(current_path))
+                 if not pixmap.isNull(): self.scale_and_set_pixmap(pixmap)
+             except Exception as e: print(f"Error reloading pixmap on resize for {current_path}: {e}")
+         elif not self.selected_image_label.text() or self.selected_image_label.text().startswith(("Cannot", "Error", "No image")):
+             self.selected_image_label.clear(); self.selected_image_label.setText("No image selected")
+
+
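+ # Entry point. The high-DPI attributes must be set before the QApplication
+ # instance is created, which is why they come first.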
+ if __name__ == "__main__":
+     QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True)  # Optional
+     QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True)  # Optional
+     app = QApplication(sys.argv)
+     app.setStyle("Fusion")  # Optional
+     window = CaptionApp()
+     window.show()
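+     # Presumably sys.exit(app.exec_()) is intended here to start the event
+     # loop; in this revision it only appears at the very end of the file
+     # (see the final hunk below).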
  import sys
  import os
  import torch
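  # NOTE: the module content appears to repeat from here; the hunks below show
  # the same lines recurring at higher line numbers in the duplicated copies.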
 

  CAPTION_LENGTH_CHOICES = (
      ["any", "very short", "short", "medium-length", "long", "very long"]
      + [str(i) for i in range(20, 261, 10)]
  )

  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 

          if self.pos_emb is not None:
              assert x.shape[-2:] == self.pos_emb.shape
+             x = x + self.pos_emb

          x = self.linear1(x)
          x = self.activation(x)
 
      else: raise ValueError(f"Invalid caption length: {length}")

      prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
+     if len(extra_options) > 0: prompt_str += " " + " ".join(extra_options)
      prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

      print(f"Prompt: {prompt_str}")
 
      if hasattr(tokenizer, "apply_chat_template"):
          convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
      else:
+         convo_string = ("<|system|>\n" + convo[0]["content"] + "\n<|end|>\n<|user|>\n" + convo[1]["content"] + "\n<|end|>\n")
      assert isinstance(convo_string, str)

      convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
 
      end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
      if end_token_id is None: raise ValueError("Tokenizer missing '<|end|>' token.")
      end_token_indices = (convo_tokens == end_token_id).nonzero(as_tuple=True)[0].tolist()
+     preamble_len = end_token_indices[0] + 1 if len(end_token_indices) >= 1 else 0

      convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
      input_embeds = torch.cat([
 
      generate_ids = generate_ids[:, input_ids.shape[1]:]
      caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
      caption = caption.strip()
+     caption = re.sub(r'\s+', ' ', caption)

      return prompt_str, caption
 
          self.selected_image_path = None
          self.image_files = []

+         self.dark_mode = False

          self.central_widget = QWidget()
          self.setCentralWidget(self.central_widget)
          self.main_layout = QHBoxLayout(self.central_widget)

+         self.initUI()  # Call initUI
          self.update_button_states()
+         self.apply_theme()

      def initUI(self):
 
          # Checkpoint Path
          ckpt_layout = QHBoxLayout()
          self.checkpoint_path_line = QLineEdit()
          self.checkpoint_path_line.setToolTip("Path to the folder containing model files (clip_model.pt, etc.).")
          ckpt_layout.addWidget(QLabel("Checkpoint Path:"))
          ckpt_layout.addWidget(self.checkpoint_path_line)
 

  CAPTION_LENGTH_CHOICES = (
      ["any", "very short", "short", "medium-length", "long", "very long"]
      + [str(i) for i in range(20, 261, 10)]
  )

  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 

          if self.pos_emb is not None:
              assert x.shape[-2:] == self.pos_emb.shape
+             x = x + self.pos_emb

          x = self.linear1(x)
          x = self.activation(x)
 
      else: raise ValueError(f"Invalid caption length: {length}")

      prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
+     if len(extra_options) > 0: prompt_str += " " + " ".join(extra_options)
      prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

      print(f"Prompt: {prompt_str}")
 
      if hasattr(tokenizer, "apply_chat_template"):
          convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
      else:
+         convo_string = ("<|system|>\n" + convo[0]["content"] + "\n<|end|>\n<|user|>\n" + convo[1]["content"] + "\n<|end|>\n")
      assert isinstance(convo_string, str)

      convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
 
      end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
      if end_token_id is None: raise ValueError("Tokenizer missing '<|end|>' token.")
      end_token_indices = (convo_tokens == end_token_id).nonzero(as_tuple=True)[0].tolist()
+     preamble_len = end_token_indices[0] + 1 if len(end_token_indices) >= 1 else 0

      convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
      input_embeds = torch.cat([
 
      generate_ids = generate_ids[:, input_ids.shape[1]:]
      caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
      caption = caption.strip()
+     caption = re.sub(r'\s+', ' ', caption)

      return prompt_str, caption
 
          self.main_layout = QHBoxLayout(self.central_widget)

          self.initUI()  # Call initUI
          self.update_button_states()
          self.apply_theme()

 
          placeholder_style = "QTextEdit { placeholderTextColor: gray; } QLineEdit { placeholderTextColor: gray; }"
          current_style = self.styleSheet()
          if self.dark_mode:
+             if "placeholderTextColor" not in current_style: self.setStyleSheet(current_style + placeholder_style)
          else: self.setStyleSheet(current_style.replace(placeholder_style, ""))

      def toggle_theme(self):
 

              # Track results (Approximate - relies on _save reporting skips)
              if caption_result is not None:
+                 processed_count += 1
              else:
                  # If None, assume error unless status bar indicates skip (imperfect)
                  if "Skipped saving" not in self.status_bar.currentMessage():
+                     error_count += 1
                  # No reliable way to count skips here without modifying _save return value

+             self.progress_bar.setValue(i + 1)
              QApplication.processEvents()

          # Restore overwrite checkbox state if changed
 

          self.progress_bar.hide()
          final_message = f"Batch finished. {processed_count} captions generated/saved."
+         if error_count > 0: final_message += f" {error_count} errors."
          # Cannot reliably report skips here
          QMessageBox.information(self, "Batch Complete", final_message)
          self.show_status(final_message, 10000)
 
      app.setStyle("Fusion")  # Optional
      window = CaptionApp()
      window.show()
+     sys.exit(app.exec_())