Ole1 committed on
Commit
efe31af
·
verified ·
1 Parent(s): cf37518

Upload Run_gui.py

Browse files
Files changed (1) hide show
  1. Run_gui.py +1538 -0
Run_gui.py ADDED
@@ -0,0 +1,1538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import torch
4
+ from torch import nn
5
+ from transformers import (
6
+ AutoModel,
7
+ AutoProcessor,
8
+ AutoTokenizer,
9
+ PreTrainedTokenizer,
10
+ PreTrainedTokenizerFast,
11
+ AutoModelForCausalLM,
12
+ BitsAndBytesConfig,
13
+ )
14
+ from PIL import Image
15
+ import torchvision.transforms.functional as TVF
16
+ import contextlib
17
+ from typing import Union, List
18
+ from pathlib import Path
19
+ import re
20
+
21
+ from PyQt5.QtWidgets import (
22
+ QApplication,
23
+ QWidget,
24
+ QLabel,
25
+ QPushButton,
26
+ QFileDialog,
27
+ QLineEdit,
28
+ QTextEdit,
29
+ QComboBox,
30
+ QVBoxLayout,
31
+ QHBoxLayout,
32
+ QCheckBox,
33
+ QListWidget,
34
+ QListWidgetItem,
35
+ QMessageBox,
36
+ QSizePolicy,
37
+ QStatusBar,
38
+ QProgressBar,
39
+ QMainWindow,
40
+ )
41
+ from PyQt5.QtGui import QPixmap, QIcon
42
+ from PyQt5.QtCore import Qt, QTimer
43
+
44
# --- Constants and Mappings ---
# Hugging Face model id of the SigLIP vision encoder used as the image backbone.
CLIP_PATH = "google/siglip-so400m-patch14-384"
# Default checkpoint directory; the GUI lets the user override this path.
CHECKPOINT_PATH = Path("cgrkzexw-599808")
# Maps each caption style to three prompt templates, indexed by length mode:
#   [0] no length constraint,
#   [1] numeric word count (fills {word_count}),
#   [2] descriptive length label (fills {length}).
# generate_caption() selects the index based on the chosen caption length.
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Descriptive (Informal)": [
        "Write a descriptive caption for this image in a casual tone.",
        "Write a descriptive caption for this image in a casual tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a casual tone.",
    ],
    "Training Prompt": [
        "Write a stable diffusion prompt for this image.",
        "Write a stable diffusion prompt for this image within {word_count} words.",
        "Write a {length} stable diffusion prompt for this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
    "Booru tag list": [
        "Write a list of Booru tags for this image.",
        "Write a list of Booru tags for this image within {word_count} words.",
        "Write a {length} list of Booru tags for this image.",
    ],
    "Booru-like tag list": [
        "Write a list of Booru-like tags for this image.",
        "Write a list of Booru-like tags for this image within {word_count} words.",
        "Write a {length} list of Booru-like tags for this image.",
    ],
    "Art Critic": [
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
    ],
    "Product Listing": [
        "Write a caption for this image as though it were a product listing.",
        "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
        "Write a {length} caption for this image as though it were a product listing.",
    ],
    "Social Media Post": [
        "Write a caption for this image as if it were being used for a social media post.",
        "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
        "Write a {length} caption for this image as if it were being used for a social media post.",
    ],
}

# Optional extra instructions the user can append to the prompt via GUI
# checkboxes. The first entry contains a {name} placeholder filled from the
# "Person/Character Name" field in generate_caption().
EXTRA_OPTIONS_LIST = [
    "If there is a person/character in the image you must refer to them as {name}.",
    "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
    "Include information about lighting.",
    "Include information about camera angle.",
    "Include information about whether there is a watermark or not.",
    "Include information about whether there are JPEG artifacts or not.",
    "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
    "Do NOT include anything sexual; keep it PG.",
    "Do NOT mention the image's resolution.",
    "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
    "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
    "Do NOT mention any text that is in the image.",
    "Specify the depth of field and whether the background is in focus or blurred.",
    "If applicable, mention the likely use of artificial or natural lighting sources.",
    "Do NOT use any ambiguous language.",
    "Include whether the image is sfw, suggestive, or nsfw.",
    "ONLY describe the most important elements of the image.",
]

# Length choices shown in the GUI: descriptive labels plus numeric word
# counts "20".."260" in steps of 10 (kept as strings for the combo box).
CAPTION_LENGTH_CHOICES = (
    ["any", "very short", "short", "medium-length", "long", "very long"]
    + [str(i) for i in range(20, 261, 10)]
)

# Optional Hugging Face auth token, read from the environment; may be None.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
121
+
122
# --- Device and Autocast Setup ---
# Prefer CUDA when available. On GPU, mixed precision uses bfloat16; on CPU
# the model runs in full float32 and autocast becomes a no-op.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32


def autocast():
    """Return a context manager for mixed-precision execution.

    On CUDA this is ``torch.amp.autocast`` with the module-level dtype;
    on CPU it is a ``contextlib.nullcontext`` so callers can always write
    ``with autocast():`` unconditionally.
    """
    # A proper def (instead of the previous lambda assignment, PEP 8 E731)
    # also collapses the two separate `device.type == "cuda"` checks into one
    # place while keeping the same call interface: autocast() -> context mgr.
    if device.type == "cuda":
        return torch.amp.autocast(device_type="cuda", dtype=torch_dtype)
    return contextlib.nullcontext()
133
+
134
# --- ImageAdapter Class ---
class ImageAdapter(nn.Module):
    """Projects CLIP vision hidden states into the LLM embedding space.

    A two-layer MLP maps per-patch vision features to LLM hidden size, then
    the sequence is wrapped with learned <|image_start|>/<|image_end|>
    marker embeddings. A third learned embedding (<|eot_id|>) is exposed
    via get_eot_embedding().
    """

    def __init__(
        self,
        input_features: int,
        output_features: int,
        ln1: bool,
        pos_emb: bool,
        num_image_tokens: int,
        deep_extract: bool,
    ):
        super().__init__()
        self.deep_extract = deep_extract

        # deep_extract concatenates 5 hidden-state layers along the feature
        # dimension in forward(), so the projection input is 5x wider.
        if self.deep_extract:
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        # Optional pre-projection LayerNorm and learned positional embedding.
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = (
            None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
        )

        # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(
            mean=0.0, std=0.02
        )

    def forward(self, vision_outputs: torch.Tensor):
        """Map CLIP hidden states to (batch, tokens+2, output_features).

        `vision_outputs` is the tuple of per-layer hidden states from the
        vision tower; the penultimate layer (and, in deep_extract mode,
        layers 3/7/13/20 as well) is used.
        """
        if self.deep_extract:
            # Concatenate five layers feature-wise; shape checks below assert
            # the expected (batch, seq, 5 * hidden) result.
            x = torch.concat(
                (
                    vision_outputs[-2],
                    vision_outputs[3],
                    vision_outputs[7],
                    vision_outputs[13],
                    vision_outputs[20],
                ),
                dim=-1,
            )
            assert len(x.shape) == 3
            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5
        else:
            x = vision_outputs[-2]

        x = self.ln1(x)

        if self.pos_emb is not None:
            # pos_emb is (num_image_tokens, input_features) and must match
            # the trailing dims of x exactly.
            assert x.shape[-2:] == self.pos_emb.shape
            x = x + self.pos_emb

        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        # Embeddings 0 and 1 are the <|image_start|>/<|image_end|> markers,
        # broadcast across the batch and placed around the image tokens.
        other_tokens = self.other_tokens(
            torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1)
        )
        assert other_tokens.shape == (x.shape[0], 2, x.shape[2])
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)

        return x

    def get_eot_embedding(self):
        # Embedding index 2 is the learned <|eot_id|> vector.
        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
202
+
203
# --- load_models Function ---
def load_models(CHECKPOINT_PATH, status_callback=None):
    """Load every model component needed for captioning.

    Loads the SigLIP processor and vision tower (overlaying the fine-tuned
    weights from clip_model.pt), the tokenizer (with extra special tokens),
    the 4-bit quantised causal LLM, and the ImageAdapter bridging vision
    features into the LLM embedding space.

    Parameters:
        CHECKPOINT_PATH: directory containing clip_model.pt,
            image_adapter.pt and a text_model/ subfolder. (Deliberately
            shadows the module-level constant of the same name so the GUI
            can pass a user-chosen path.)
        status_callback: optional callable taking a progress-message string;
            every message is also printed to the console.

    Returns:
        (clip_processor, clip_model, tokenizer, text_model, image_adapter)

    Raises:
        FileNotFoundError: clip_model.pt or image_adapter.pt is missing.
        TypeError: the loaded tokenizer is not a PreTrainedTokenizer(Fast).
    """
    def update_status(msg):
        # Forward progress to the GUI when a callback is supplied.
        if status_callback:
            status_callback(msg)
        print(msg)  # Keep console output

    update_status("Loading CLIP processor...")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    update_status("Loading CLIP vision model...")
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    # Only the vision tower is needed; the text half of SigLIP is discarded.
    clip_model = clip_model.vision_model

    clip_model_path = CHECKPOINT_PATH / "clip_model.pt"
    if not clip_model_path.exists():
        raise FileNotFoundError(f"clip_model.pt not found in {CHECKPOINT_PATH}")

    update_status("Loading VLM's custom vision weights...")
    checkpoint = torch.load(clip_model_path, map_location="cpu")
    # Strip torch.compile / DataParallel prefixes left over from training.
    checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
    clip_model.load_state_dict(checkpoint)
    del checkpoint

    clip_model.eval()
    clip_model.requires_grad_(False)
    update_status(f"Moving CLIP to {device}...")
    clip_model.to(device)

    update_status("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        CHECKPOINT_PATH / "text_model", use_fast=True
    )
    if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        raise TypeError(f"Tokenizer is of type {type(tokenizer)}")

    # Chat-formatting markers used when building the conversation string.
    special_tokens_dict = {'additional_special_tokens': ['<|system|>', '<|user|>', '<|end|>', '<|eot_id|>']}
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    update_status(f"Added {num_added_toks} special tokens.")

    update_status("Loading LLM with 4-bit quantization (this may take time)...")
    text_model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_PATH / "text_model",
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_compute_dtype=torch.float16
        )
    )
    text_model.eval()

    # Must happen after add_special_tokens so the embedding matrix covers
    # the newly added token ids.
    if num_added_toks > 0:
        update_status("Resizing LLM token embeddings...")
        text_model.resize_token_embeddings(len(tokenizer))

    update_status("Loading image adapter...")
    # ln1=False, pos_emb=False, 38 image tokens, deep_extract=False — these
    # must match the configuration the adapter checkpoint was trained with.
    image_adapter = ImageAdapter(
        clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False
    )
    image_adapter_path = CHECKPOINT_PATH / "image_adapter.pt"
    if not image_adapter_path.exists():
        raise FileNotFoundError(f"image_adapter.pt not found in {CHECKPOINT_PATH}")

    image_adapter.load_state_dict(
        torch.load(image_adapter_path, map_location="cpu")
    )
    image_adapter.eval()
    update_status(f"Moving image adapter to {device}...")
    image_adapter.to(device)

    update_status("Models loaded successfully.")
    return clip_processor, clip_model, tokenizer, text_model, image_adapter
276
+
277
# --- generate_caption Function ---
@torch.no_grad()
def generate_caption(
    input_image: Image.Image,
    caption_type: str,
    caption_length: Union[str, int],
    extra_options: List[str],
    name_input: str,
    custom_prompt: str,
    clip_model,
    tokenizer,
    text_model,
    image_adapter,
) -> tuple:
    """Generate a caption for one image.

    Builds a prompt from caption_type/caption_length/extra_options (or uses
    custom_prompt verbatim when non-empty), encodes the image through the
    CLIP tower and ImageAdapter, splices the image embeddings into the chat
    preamble, and samples the LLM.

    Returns:
        (prompt_str, caption) — the prompt actually used and the generated,
        whitespace-normalised caption.

    Raises:
        ValueError: on bad caption length, image-conversion failure, or a
            tokenizer missing the '<|end|>' token.
    """
    if device.type == "cuda":
        torch.cuda.empty_cache()

    if custom_prompt.strip() != "":
        # A non-empty custom prompt overrides every other prompt setting.
        prompt_str = custom_prompt.strip()
    else:
        length = None if caption_length == "any" else caption_length
        if isinstance(length, str):
            # Numeric choices arrive as strings from the GUI combo box.
            try:
                length = int(length)
            except ValueError:
                pass

        # Template index: 0 = no constraint, 1 = word count, 2 = length label.
        if length is None: map_idx = 0
        elif isinstance(length, int): map_idx = 1
        elif isinstance(length, str): map_idx = 2
        else: raise ValueError(f"Invalid caption length: {length}")

        prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
        if len(extra_options) > 0: prompt_str += " " + " ".join(extra_options)
        # Fill {name}/{length}/{word_count} placeholders used by templates
        # and extra options; unused keyword args are ignored by str.format.
        prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    print(f"Prompt: {prompt_str}")

    try:
        image = input_image.convert("RGB")
    except Exception as e: raise ValueError(f"Error converting image to RGB: {e}")
    if image.mode != "RGB": raise ValueError(f"Image mode after conversion is {image.mode}, expected 'RGB'.")

    # Preprocess to the SigLIP input: 384x384, scaled to [0,1], then
    # normalised to [-1,1] per channel.
    image = image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    pixel_values = pixel_values.to(device)

    with autocast():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        embedded_images = image_adapter(vision_outputs.hidden_states)
        embedded_images = embedded_images.to(device)

    convo = [
        {"role": "system", "content": "You are a helpful image captioner."},
        {"role": "user", "content": prompt_str},
    ]

    if hasattr(tokenizer, "apply_chat_template"):
        convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    else:
        # Fallback manual chat format using the special tokens added at load.
        convo_string = ("<|system|>\n" + convo[0]["content"] + "\n<|end|>\n<|user|>\n" + convo[1]["content"] + "\n<|end|>\n")
    assert isinstance(convo_string, str)

    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
    assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
    convo_tokens = convo_tokens.squeeze(0)
    prompt_tokens = prompt_tokens.squeeze(0)

    # The image embeddings are inserted right after the first '<|end|>'
    # (i.e. after the system preamble, before the user turn).
    end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
    if end_token_id is None: raise ValueError("Tokenizer missing '<|end|>' token.")
    end_token_indices = (convo_tokens == end_token_id).nonzero(as_tuple=True)[0].tolist()
    preamble_len = end_token_indices[0] + 1 if len(end_token_indices) >= 1 else 0

    # NOTE(review): assumes the causal LM exposes its embedding table at
    # text_model.model.embed_tokens — true for Llama-style models; verify
    # for other architectures.
    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
    input_embeds = torch.cat([
        convo_embeds[:, :preamble_len],
        embedded_images.to(dtype=convo_embeds.dtype),
        convo_embeds[:, preamble_len:],
    ], dim=1).to(device)

    # Parallel input_ids: pad tokens stand in for the image-embedding slots
    # so lengths match input_embeds for generate().
    input_ids = torch.cat([
        convo_tokens[:preamble_len].unsqueeze(0),
        torch.full((1, embedded_images.shape[1]), tokenizer.pad_token_id, dtype=torch.long, device=device),
        convo_tokens[preamble_len:].unsqueeze(0),
    ], dim=1).to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")

    generate_ids = text_model.generate(
        input_ids=input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask,
        max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9,
        suppress_tokens=None, eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")]
    )

    # Drop the prompt portion, then decode and collapse all whitespace runs.
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
    caption = caption.strip()
    caption = re.sub(r'\s+', ' ', caption)

    return prompt_str, caption
380
+
381
+ class CaptionApp(QMainWindow):
382
    def __init__(self):
        """Set up window geometry, model/state fields, and the widget layout."""
        super().__init__()
        self.setWindowTitle("JoyCaption Alpha Two - Enhanced")
        self.setGeometry(100, 100, 1200, 850)
        self.setMinimumSize(1000, 750)

        # Model handles; all None until load_models_action() succeeds.
        self.clip_processor = None
        self.clip_model = None
        self.tokenizer = None
        self.text_model = None
        self.image_adapter = None
        self.models_loaded = False

        # Current input selection state.
        self.input_dir = None          # batch-mode directory
        self.single_image_path = None  # image picked via "Select Single Image"
        self.selected_image_path = None  # image highlighted in the list widget
        self.image_files = []          # images discovered in input_dir

        self.dark_mode = False

        # QMainWindow needs an explicit central widget to host the layout.
        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)
        self.main_layout = QHBoxLayout(self.central_widget)

        self.initUI()  # Build all widgets and panels.
        self.update_button_states()
        self.apply_theme()
410
+
411
+
412
    def initUI(self):
        """Construct the two-panel UI.

        Left panel: input selection, caption settings, checkpoint path, and
        action buttons. Right panel: image list, preview, editable caption
        box, and save options. Also installs the status bar + progress bar.
        """
        # --- Left Panel ---
        left_panel = QVBoxLayout()
        left_panel.setSpacing(10)

        # Input directory selection
        dir_layout = QHBoxLayout()
        self.input_dir_button = QPushButton("Select Input Directory")
        self.input_dir_button.setToolTip("Select a folder containing images to process in batch.")
        self.input_dir_button.clicked.connect(self.select_input_directory)
        dir_layout.addWidget(self.input_dir_button)
        self.input_dir_label = QLabel("No directory selected")
        self.input_dir_label.setWordWrap(True)
        dir_layout.addWidget(self.input_dir_label, 1)
        left_panel.addLayout(dir_layout)

        # Single image selection
        single_img_layout = QHBoxLayout()
        self.single_image_button = QPushButton("Select Single Image")
        self.single_image_button.setToolTip("Select one image file to process.")
        self.single_image_button.clicked.connect(self.select_single_image)
        single_img_layout.addWidget(self.single_image_button)
        self.single_image_label = QLabel("No image selected")
        self.single_image_label.setWordWrap(True)
        single_img_layout.addWidget(self.single_image_label, 1)
        left_panel.addLayout(single_img_layout)

        # Caption Type
        self.caption_type_combo = QComboBox()
        self.caption_type_combo.addItems(CAPTION_TYPE_MAP.keys())
        self.caption_type_combo.setCurrentText("Descriptive")
        self.caption_type_combo.setToolTip("Choose the style or purpose of the caption.")
        left_panel.addWidget(QLabel("Caption Type:"))
        left_panel.addWidget(self.caption_type_combo)

        # Caption Length
        self.caption_length_combo = QComboBox()
        self.caption_length_combo.addItems(CAPTION_LENGTH_CHOICES)
        self.caption_length_combo.setCurrentText("long")
        self.caption_length_combo.setToolTip("Select desired caption length or word count.")
        left_panel.addWidget(QLabel("Caption Length:"))
        left_panel.addWidget(self.caption_length_combo)

        # Extra Options — one checkbox per optional prompt instruction.
        left_panel.addWidget(QLabel("Extra Options:"))
        self.extra_options_checkboxes = []
        for option in EXTRA_OPTIONS_LIST:
            checkbox = QCheckBox(option)
            checkbox.setToolTip(option)
            self.extra_options_checkboxes.append(checkbox)
            left_panel.addWidget(checkbox)

        # Name Input — fills the {name} placeholder of the first extra option.
        self.name_input_line = QLineEdit()
        self.name_input_line.setPlaceholderText("e.g., 'the main character'")
        self.name_input_line.setToolTip("If the first extra option is checked, this name will be used.")
        left_panel.addWidget(QLabel("Person/Character Name (optional):"))
        left_panel.addWidget(self.name_input_line)

        # Custom Prompt — when non-empty it overrides type/length/options.
        self.custom_prompt_text = QTextEdit()
        self.custom_prompt_text.setPlaceholderText("Overrides Caption Type/Length/Options if used.")
        self.custom_prompt_text.setToolTip("Enter a full custom prompt here to ignore other settings.")
        self.custom_prompt_text.setFixedHeight(80)
        left_panel.addWidget(QLabel("Custom Prompt (optional):"))
        left_panel.addWidget(self.custom_prompt_text)

        # Checkpoint Path
        ckpt_layout = QHBoxLayout()
        self.checkpoint_path_line = QLineEdit()
        self.checkpoint_path_line.setText("cgrkzexw-599808")
        self.checkpoint_path_line.setToolTip("Path to the folder containing model files (clip_model.pt, etc.).")
        ckpt_layout.addWidget(QLabel("Checkpoint Path:"))
        ckpt_layout.addWidget(self.checkpoint_path_line)
        self.browse_ckpt_button = QPushButton("...")
        self.browse_ckpt_button.setToolTip("Browse for Checkpoint Directory")
        self.browse_ckpt_button.clicked.connect(self.browse_checkpoint_path)
        self.browse_ckpt_button.setMaximumWidth(30)
        ckpt_layout.addWidget(self.browse_ckpt_button)
        left_panel.addLayout(ckpt_layout)

        # Load Models Button
        self.load_models_button = QPushButton("Load Models")
        self.load_models_button.setToolTip("Load the AI models into memory (requires checkpoint path).")
        self.load_models_button.clicked.connect(self.load_models_action)
        left_panel.addWidget(self.load_models_button)

        # Run Buttons — batch, list-selection, and single-image actions.
        self.run_button = QPushButton("Generate Captions for All Images in Directory")
        self.run_button.setToolTip("Process all loaded images from the selected directory.")
        self.run_button.clicked.connect(self.generate_captions_action)
        left_panel.addWidget(self.run_button)

        self.caption_selected_button = QPushButton("Caption Selected Image from List")
        self.caption_selected_button.setToolTip("Process the image currently highlighted in the list.")
        self.caption_selected_button.clicked.connect(self.caption_selected_image_action)
        left_panel.addWidget(self.caption_selected_button)

        self.caption_single_button = QPushButton("Caption Single Loaded Image")
        self.caption_single_button.setToolTip("Process the image selected via 'Select Single Image'.")
        self.caption_single_button.clicked.connect(self.caption_single_image_action)
        left_panel.addWidget(self.caption_single_button)

        # Theme Toggle Button
        self.toggle_theme_button = QPushButton("Toggle Dark Mode")
        self.toggle_theme_button.setToolTip("Switch between light and dark themes.")
        self.toggle_theme_button.clicked.connect(self.toggle_theme)
        left_panel.addWidget(self.toggle_theme_button)

        left_panel.addStretch(1)

        # --- Right Panel ---
        right_panel = QVBoxLayout()
        right_panel.setSpacing(10)

        # List widget for images
        right_panel.addWidget(QLabel("Images in Directory:"))
        self.image_list_widget = QListWidget()
        self.image_list_widget.setIconSize(self.image_list_widget.iconSize() * 2)
        self.image_list_widget.itemClicked.connect(self.display_selected_image)
        self.image_list_widget.setToolTip("Click an image to view it and enable 'Caption Selected Image'.")
        right_panel.addWidget(self.image_list_widget, 1)

        # Label to display the selected image
        right_panel.addWidget(QLabel("Selected Image Preview:"))
        self.selected_image_label = QLabel("No image selected")
        self.selected_image_label.setAlignment(Qt.AlignCenter)
        self.selected_image_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.selected_image_label.setMinimumSize(300, 300)
        self.selected_image_label.setStyleSheet("border: 1px solid gray;")
        right_panel.addWidget(self.selected_image_label, 3)

        # Generated Caption Area — editable so the user can fix up results.
        right_panel.addWidget(QLabel("Generated/Editable Caption:"))
        self.generated_caption_text = QTextEdit()
        self.generated_caption_text.setReadOnly(False)
        self.generated_caption_text.setPlaceholderText("Generated caption will appear here. You can edit it before saving.")
        self.generated_caption_text.setToolTip("The generated caption appears here. Edit and use 'Save Edited Caption'.")
        right_panel.addWidget(self.generated_caption_text, 1)

        # Save-behaviour toggles for existing .txt caption files.
        self.overwrite_checkbox = QCheckBox("Overwrite existing captions")
        self.overwrite_checkbox.setToolTip("If checked, automatically overwrites existing .txt files without asking.")
        self.append_checkbox = QCheckBox("Append to existing captions")
        self.append_checkbox.setToolTip("If checked, adds the new caption to the end of the existing .txt file.")

        # Layout for the save options
        save_options_layout = QHBoxLayout()
        save_options_layout.addWidget(self.overwrite_checkbox)
        save_options_layout.addWidget(self.append_checkbox)
        save_options_layout.addStretch(1)
        right_panel.addLayout(save_options_layout)

        # Append and overwrite are mutually exclusive: checking "append"
        # disables the overwrite checkbox.
        self.append_checkbox.stateChanged.connect(
            lambda state: self.overwrite_checkbox.setEnabled(state == Qt.Unchecked)
        )

        # Save Edited Caption Button
        self.save_caption_button = QPushButton("Save Edited Caption to File")
        self.save_caption_button.setToolTip("Save the text currently in the box above to the corresponding .txt file using the selected options.")
        self.save_caption_button.clicked.connect(self.save_edited_caption_action)
        right_panel.addWidget(self.save_caption_button)

        # --- Main Layout Assembly
        self.main_layout.addLayout(left_panel, 2)
        self.main_layout.addLayout(right_panel, 5)

        # --- Status Bar and Progress Bar
        self.status_bar = QStatusBar()
        self.setStatusBar(self.status_bar)
        self.progress_bar = QProgressBar()
        self.status_bar.addPermanentWidget(self.progress_bar)
        self.progress_bar.hide()
        self.show_status("Ready.", 5000)
591
+
592
+
593
+
594
+ import sys
595
+ import os
596
+ import torch
597
+ from torch import nn
598
+ from transformers import (
599
+ AutoModel,
600
+ AutoProcessor,
601
+ AutoTokenizer,
602
+ PreTrainedTokenizer,
603
+ PreTrainedTokenizerFast,
604
+ AutoModelForCausalLM,
605
+ BitsAndBytesConfig,
606
+ )
607
+ from PIL import Image
608
+ import torchvision.transforms.functional as TVF
609
+ import contextlib
610
+ from typing import Union, List
611
+ from pathlib import Path
612
+ import re # Added for spacing fix
613
+
614
+ from PyQt5.QtWidgets import (
615
+ QApplication,
616
+ QWidget,
617
+ QLabel,
618
+ QPushButton,
619
+ QFileDialog,
620
+ QLineEdit,
621
+ QTextEdit,
622
+ QComboBox,
623
+ QVBoxLayout,
624
+ QHBoxLayout,
625
+ QCheckBox,
626
+ QListWidget,
627
+ QListWidgetItem,
628
+ QMessageBox,
629
+ QSizePolicy,
630
+ QStatusBar,
631
+ QProgressBar,
632
+ QMainWindow,
633
+ )
634
+ from PyQt5.QtGui import QPixmap, QIcon
635
+ from PyQt5.QtCore import Qt, QTimer
636
+
637
+ # --- Constants and Mappings ---
638
+ CLIP_PATH = "google/siglip-so400m-patch14-384"
639
+ CHECKPOINT_PATH = Path("cgrkzexw-599808")
640
+ CAPTION_TYPE_MAP = {
641
+ "Descriptive": [
642
+ "Write a descriptive caption for this image in a formal tone.",
643
+ "Write a descriptive caption for this image in a formal tone within {word_count} words.",
644
+ "Write a {length} descriptive caption for this image in a formal tone.",
645
+ ],
646
+ "Descriptive (Informal)": [
647
+ "Write a descriptive caption for this image in a casual tone.",
648
+ "Write a descriptive caption for this image in a casual tone within {word_count} words.",
649
+ "Write a {length} descriptive caption for this image in a casual tone.",
650
+ ],
651
+ "Training Prompt": [
652
+ "Write a stable diffusion prompt for this image.",
653
+ "Write a stable diffusion prompt for this image within {word_count} words.",
654
+ "Write a {length} stable diffusion prompt for this image.",
655
+ ],
656
+ "MidJourney": [
657
+ "Write a MidJourney prompt for this image.",
658
+ "Write a MidJourney prompt for this image within {word_count} words.",
659
+ "Write a {length} MidJourney prompt for this image.",
660
+ ],
661
+ "Booru tag list": [
662
+ "Write a list of Booru tags for this image.",
663
+ "Write a list of Booru tags for this image within {word_count} words.",
664
+ "Write a {length} list of Booru tags for this image.",
665
+ ],
666
+ "Booru-like tag list": [
667
+ "Write a list of Booru-like tags for this image.",
668
+ "Write a list of Booru-like tags for this image within {word_count} words.",
669
+ "Write a {length} list of Booru-like tags for this image.",
670
+ ],
671
+ "Art Critic": [
672
+ "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
673
+ "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
674
+ "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
675
+ ],
676
+ "Product Listing": [
677
+ "Write a caption for this image as though it were a product listing.",
678
+ "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
679
+ "Write a {length} caption for this image as though it were a product listing.",
680
+ ],
681
+ "Social Media Post": [
682
+ "Write a caption for this image as if it were being used for a social media post.",
683
+ "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
684
+ "Write a {length} caption for this image as if it were being used for a social media post.",
685
+ ],
686
+ }
687
+
688
# Checkbox sentences offered in the GUI.  Any checked entries are appended
# verbatim to the generated prompt (see generate_caption); the "{name}"
# placeholder in the first entry is later filled in via str.format with the
# Person/Character Name field.
EXTRA_OPTIONS_LIST = [
    "If there is a person/character in the image you must refer to them as {name}.",
    "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
    "Include information about lighting.",
    "Include information about camera angle.",
    "Include information about whether there is a watermark or not.",
    "Include information about whether there are JPEG artifacts or not.",
    "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
    "Do NOT include anything sexual; keep it PG.",
    "Do NOT mention the image's resolution.",
    "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
    "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
    "Do NOT mention any text that is in the image.",
    "Specify the depth of field and whether the background is in focus or blurred.",
    "If applicable, mention the likely use of artificial or natural lighting sources.",
    "Do NOT use any ambiguous language.",
    "Include whether the image is sfw, suggestive, or nsfw.",
    "ONLY describe the most important elements of the image.",
]
707
+
708
# Selectable caption lengths: six qualitative sizes followed by explicit
# word counts from 20 up to 260 in steps of 10.
_QUALITATIVE_LENGTHS = ["any", "very short", "short", "medium-length", "long", "very long"]
CAPTION_LENGTH_CHOICES = _QUALITATIVE_LENGTHS + list(map(str, range(20, 261, 10)))
712
+
713
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
714
+
715
# --- Device and Autocast Setup ---
# Prefer CUDA when available: bfloat16 autocast on GPU, plain float32 on CPU
# (where autocast is a no-op context manager).
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch_dtype = torch.bfloat16

    def autocast():
        return torch.amp.autocast(device_type='cuda', dtype=torch_dtype)
else:
    device = torch.device("cpu")
    torch_dtype = torch.float32
    autocast = contextlib.nullcontext
726
+
727
# --- ImageAdapter Class ---
class ImageAdapter(nn.Module):
    """Project CLIP vision features into the language model's embedding space.

    The adapter pushes hidden states through a two-layer MLP and brackets the
    resulting image tokens with learned <|image_start|> / <|image_end|>
    embeddings.  A third learned embedding, <|eot_id|>, is exposed via
    get_eot_embedding().  Attribute names (linear1, ln1, pos_emb,
    other_tokens, ...) are part of the checkpoint state-dict layout and must
    not be renamed.
    """

    def __init__(
        self,
        input_features: int,
        output_features: int,
        ln1: bool,
        pos_emb: bool,
        num_image_tokens: int,
        deep_extract: bool,
    ):
        super().__init__()
        self.deep_extract = deep_extract

        # Deep extraction concatenates five hidden states, so the effective
        # input width is five times larger.
        if self.deep_extract:
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.LayerNorm(input_features) if ln1 else nn.Identity()
        self.pos_emb = nn.Parameter(torch.zeros(num_image_tokens, input_features)) if pos_emb else None

        # Learned embeddings for <|image_start|>, <|image_end|>, <|eot_id|>.
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        """Map a tuple/list of vision hidden states to image token embeddings.

        Returns a tensor of shape (batch, num_tokens + 2, output_features):
        the projected tokens wrapped by the start/end marker embeddings.
        """
        if self.deep_extract:
            selected = (
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            )
            x = torch.concat(selected, dim=-1)
            assert len(x.shape) == 3
            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5
        else:
            # Penultimate hidden state only.
            x = vision_outputs[-2]

        x = self.ln1(x)

        if self.pos_emb is not None:
            assert x.shape[-2:] == self.pos_emb.shape
            x = x + self.pos_emb

        # Two-layer MLP projection into the text-model hidden size.
        x = self.linear2(self.activation(self.linear1(x)))

        # Surround the image tokens with <|image_start|> / <|image_end|>.
        marker_ids = torch.tensor([0, 1], device=self.other_tokens.weight.device)
        markers = self.other_tokens(marker_ids.expand(x.shape[0], -1))
        assert markers.shape == (x.shape[0], 2, x.shape[2])
        return torch.cat((markers[:, 0:1], x, markers[:, 1:2]), dim=1)

    def get_eot_embedding(self):
        """Return the learned <|eot_id|> embedding vector."""
        eot_index = torch.tensor([2], device=self.other_tokens.weight.device)
        return self.other_tokens(eot_index).squeeze(0)
796
# --- load_models Function ---
def load_models(CHECKPOINT_PATH, status_callback=None):
    """Load all model components needed for captioning.

    Loads the CLIP vision tower (patched with fine-tuned weights from the
    checkpoint), the tokenizer with extra special tokens, the 4-bit
    quantized causal LM, and the ImageAdapter bridging vision and text.

    Args:
        CHECKPOINT_PATH: pathlib.Path to the checkpoint directory; must
            contain 'clip_model.pt', 'image_adapter.pt' and a 'text_model'
            subfolder.
        status_callback: Optional callable taking a message string; used to
            surface progress in the GUI.

    Returns:
        Tuple (clip_processor, clip_model, tokenizer, text_model,
        image_adapter).

    Raises:
        FileNotFoundError: If a required checkpoint file is missing.
        TypeError: If the loaded tokenizer is not a supported type.
    """
    def update_status(msg):
        # Report to the GUI when a callback is provided, and always echo
        # to the console.
        if status_callback:
            status_callback(msg)
        print(msg) # Keep console output

    update_status("Loading CLIP processor...")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    update_status("Loading CLIP vision model...")
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    # Only the vision tower is used; the text half of CLIP is discarded.
    clip_model = clip_model.vision_model

    clip_model_path = CHECKPOINT_PATH / "clip_model.pt"
    if not clip_model_path.exists():
        raise FileNotFoundError(f"clip_model.pt not found in {CHECKPOINT_PATH}")

    update_status("Loading VLM's custom vision weights...")
    checkpoint = torch.load(clip_model_path, map_location="cpu")
    # Strip the torch.compile/DistributedDataParallel key prefix so the keys
    # match the bare vision module.
    checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
    clip_model.load_state_dict(checkpoint)
    del checkpoint

    clip_model.eval()
    clip_model.requires_grad_(False)
    update_status(f"Moving CLIP to {device}...")
    clip_model.to(device)

    update_status("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        CHECKPOINT_PATH / "text_model", use_fast=True
    )
    if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        raise TypeError(f"Tokenizer is of type {type(tokenizer)}")

    # Chat-markup tokens used when building the conversation string in
    # generate_caption.
    special_tokens_dict = {'additional_special_tokens': ['<|system|>', '<|user|>', '<|end|>', '<|eot_id|>']}
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    update_status(f"Added {num_added_toks} special tokens.")

    update_status("Loading LLM with 4-bit quantization (this may take time)...")
    text_model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_PATH / "text_model",
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_compute_dtype=torch.float16
        )
    )
    text_model.eval()

    # Grow the embedding matrix to cover any newly added special tokens.
    if num_added_toks > 0:
        update_status("Resizing LLM token embeddings...")
        text_model.resize_token_embeddings(len(tokenizer))

    update_status("Loading image adapter...")
    # NOTE(review): 38 appears to be the number of image tokens this
    # checkpoint was trained with; ln1/pos_emb/deep_extract are disabled —
    # confirm against the training configuration.
    image_adapter = ImageAdapter(
        clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False
    )
    image_adapter_path = CHECKPOINT_PATH / "image_adapter.pt"
    if not image_adapter_path.exists():
        raise FileNotFoundError(f"image_adapter.pt not found in {CHECKPOINT_PATH}")

    image_adapter.load_state_dict(
        torch.load(image_adapter_path, map_location="cpu")
    )
    image_adapter.eval()
    update_status(f"Moving image adapter to {device}...")
    image_adapter.to(device)

    update_status("Models loaded successfully.")
    return clip_processor, clip_model, tokenizer, text_model, image_adapter
869
+
870
# --- generate_caption Function ---
@torch.no_grad()
def generate_caption(
    input_image: Image.Image,
    caption_type: str,
    caption_length: Union[str, int],
    extra_options: List[str],
    name_input: str,
    custom_prompt: str,
    clip_model,
    tokenizer,
    text_model,
    image_adapter,
) -> tuple:
    """Generate a caption for one image.

    Builds a prompt (a non-empty custom_prompt overrides all other prompt
    settings), encodes the image through CLIP + ImageAdapter, splices the
    image embeddings into the conversation right after the system preamble,
    and samples a caption from the LLM.

    Returns:
        Tuple (prompt_str, caption) — the prompt actually used and the
        whitespace-normalized generated caption.

    Raises:
        ValueError: On an invalid caption length, a failed RGB conversion,
            or a tokenizer missing the '<|end|>' token.
    """
    if device.type == "cuda":
        torch.cuda.empty_cache()

    if custom_prompt.strip() != "":
        prompt_str = custom_prompt.strip()
    else:
        # "any" means no length constraint; numeric strings become ints so
        # they select the {word_count} template variant below.
        length = None if caption_length == "any" else caption_length
        if isinstance(length, str):
            try:
                length = int(length)
            except ValueError:
                pass

        # Template index: 0 = unconstrained, 1 = word-count, 2 = qualitative.
        if length is None: map_idx = 0
        elif isinstance(length, int): map_idx = 1
        elif isinstance(length, str): map_idx = 2
        else: raise ValueError(f"Invalid caption length: {length}")

        prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
        if len(extra_options) > 0: prompt_str += " " + " ".join(extra_options)
        prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    print(f"Prompt: {prompt_str}")

    try:
        image = input_image.convert("RGB")
    except Exception as e: raise ValueError(f"Error converting image to RGB: {e}")
    if image.mode != "RGB": raise ValueError(f"Image mode after conversion is {image.mode}, expected 'RGB'.")

    # Preprocess to the CLIP input format: 384x384, scaled to [0,1], then
    # normalized to [-1,1].
    image = image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    pixel_values = pixel_values.to(device)

    # Encode image: CLIP hidden states -> adapter -> LLM-space embeddings.
    with autocast():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        embedded_images = image_adapter(vision_outputs.hidden_states)
        embedded_images = embedded_images.to(device)

    convo = [
        {"role": "system", "content": "You are a helpful image captioner."},
        {"role": "user", "content": prompt_str},
    ]

    # Prefer the tokenizer's own chat template; fall back to the custom
    # <|system|>/<|user|>/<|end|> markup added in load_models.
    if hasattr(tokenizer, "apply_chat_template"):
        convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    else:
        convo_string = ("<|system|>\n" + convo[0]["content"] + "\n<|end|>\n<|user|>\n" + convo[1]["content"] + "\n<|end|>\n")
    assert isinstance(convo_string, str)

    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
    assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
    convo_tokens = convo_tokens.squeeze(0)
    prompt_tokens = prompt_tokens.squeeze(0)

    # The image embeddings are inserted immediately after the first
    # '<|end|>' (i.e. after the system preamble); preamble_len is the
    # number of tokens up to and including that marker.
    end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
    if end_token_id is None: raise ValueError("Tokenizer missing '<|end|>' token.")
    end_token_indices = (convo_tokens == end_token_id).nonzero(as_tuple=True)[0].tolist()
    preamble_len = end_token_indices[0] + 1 if len(end_token_indices) >= 1 else 0

    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
    input_embeds = torch.cat([
        convo_embeds[:, :preamble_len],
        embedded_images.to(dtype=convo_embeds.dtype),
        convo_embeds[:, preamble_len:],
    ], dim=1).to(device)

    # Mirror the splice in token-id space, using pad tokens as stand-ins
    # for the image positions so input_ids and input_embeds stay aligned.
    input_ids = torch.cat([
        convo_tokens[:preamble_len].unsqueeze(0),
        torch.full((1, embedded_images.shape[1]), tokenizer.pad_token_id, dtype=torch.long, device=device),
        convo_tokens[preamble_len:].unsqueeze(0),
    ], dim=1).to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")

    generate_ids = text_model.generate(
        input_ids=input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask,
        max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9,
        suppress_tokens=None, eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")]
    )

    # Keep only the newly generated tokens, then collapse whitespace.
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
    caption = caption.strip()
    caption = re.sub(r'\s+', ' ', caption)

    return prompt_str, caption
973
+
974
+ # --- CaptionApp Class ---
975
+ class CaptionApp(QMainWindow):
976
    def __init__(self):
        """Set up window geometry, state, widgets, and try to auto-load models."""
        super().__init__()
        self.setWindowTitle("JoyCaption Alpha Two - Enhanced")
        self.setGeometry(100, 100, 1200, 850)
        self.setMinimumSize(1000, 750)

        # Model components; populated by load_models (or auto-load below).
        self.clip_processor = None
        self.clip_model = None
        self.tokenizer = None
        self.text_model = None
        self.image_adapter = None
        self.models_loaded = False

        # Image-selection state: a batch directory, a single file, or a
        # list-selected file (mutually exclusive in practice).
        self.input_dir = None
        self.single_image_path = None
        self.selected_image_path = None
        self.image_files = []

        self.dark_mode = False

        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)
        self.main_layout = QHBoxLayout(self.central_widget)

        self.initUI() # Call initUI

        # Attempt to auto-load models at startup

        # Only if the default checkpoint folder exists next to the script;
        # failures are non-fatal and leave the app usable without models.
        if Path("cgrkzexw-599808").exists():
            try:
                (self.clip_processor, self.clip_model, self.tokenizer, self.text_model, self.image_adapter) = load_models(Path("cgrkzexw-599808"), status_callback=self.show_status)
                self.models_loaded = True
                self.show_status("Models loaded at startup.", 5000)
            except Exception as e:
                print("Auto-load failed:", e)
                self.models_loaded = False

        self.update_button_states()
        self.apply_theme()
1015
+
1016
    def initUI(self):
        """Build the full widget tree: controls on the left, image list,
        preview, caption editor and save options on the right, plus the
        status bar with an (initially hidden) progress bar."""
        # --- Left Panel ---
        left_panel = QVBoxLayout()
        left_panel.setSpacing(10)

        # Input directory selection
        dir_layout = QHBoxLayout()
        self.input_dir_button = QPushButton("Select Input Directory")
        self.input_dir_button.setToolTip("Select a folder containing images to process in batch.")
        self.input_dir_button.clicked.connect(self.select_input_directory)
        dir_layout.addWidget(self.input_dir_button)
        self.input_dir_label = QLabel("No directory selected")
        self.input_dir_label.setWordWrap(True)
        dir_layout.addWidget(self.input_dir_label, 1)
        left_panel.addLayout(dir_layout)

        # Single image selection
        single_img_layout = QHBoxLayout()
        self.single_image_button = QPushButton("Select Single Image")
        self.single_image_button.setToolTip("Select one image file to process.")
        self.single_image_button.clicked.connect(self.select_single_image)
        single_img_layout.addWidget(self.single_image_button)
        self.single_image_label = QLabel("No image selected")
        self.single_image_label.setWordWrap(True)
        single_img_layout.addWidget(self.single_image_label, 1)
        left_panel.addLayout(single_img_layout)

        # Caption Type
        self.caption_type_combo = QComboBox()
        self.caption_type_combo.addItems(CAPTION_TYPE_MAP.keys())
        self.caption_type_combo.setCurrentText("Descriptive")
        self.caption_type_combo.setToolTip("Choose the style or purpose of the caption.")
        left_panel.addWidget(QLabel("Caption Type:"))
        left_panel.addWidget(self.caption_type_combo)

        # Caption Length
        self.caption_length_combo = QComboBox()
        self.caption_length_combo.addItems(CAPTION_LENGTH_CHOICES)
        self.caption_length_combo.setCurrentText("long")
        self.caption_length_combo.setToolTip("Select desired caption length or word count.")
        left_panel.addWidget(QLabel("Caption Length:"))
        left_panel.addWidget(self.caption_length_combo)

        # Extra Options (one checkbox per EXTRA_OPTIONS_LIST entry)
        left_panel.addWidget(QLabel("Extra Options:"))
        self.extra_options_checkboxes = []
        for option in EXTRA_OPTIONS_LIST:
            checkbox = QCheckBox(option)
            checkbox.setToolTip(option)
            self.extra_options_checkboxes.append(checkbox)
            left_panel.addWidget(checkbox)

        # Name Input
        self.name_input_line = QLineEdit()
        self.name_input_line.setPlaceholderText("e.g., 'the main character'")
        self.name_input_line.setToolTip("If the first extra option is checked, this name will be used.")
        left_panel.addWidget(QLabel("Person/Character Name (optional):"))
        left_panel.addWidget(self.name_input_line)

        # Custom Prompt
        self.custom_prompt_text = QTextEdit()
        self.custom_prompt_text.setPlaceholderText("Overrides Caption Type/Length/Options if used.")
        self.custom_prompt_text.setToolTip("Enter a full custom prompt here to ignore other settings.")
        self.custom_prompt_text.setFixedHeight(80)
        left_panel.addWidget(QLabel("Custom Prompt (optional):"))
        left_panel.addWidget(self.custom_prompt_text)

        # Checkpoint Path
        ckpt_layout = QHBoxLayout()
        self.checkpoint_path_line = QLineEdit()
        self.checkpoint_path_line.setToolTip("Path to the folder containing model files (clip_model.pt, etc.).")
        ckpt_layout.addWidget(QLabel("Checkpoint Path:"))
        ckpt_layout.addWidget(self.checkpoint_path_line)
        self.browse_ckpt_button = QPushButton("...")
        self.browse_ckpt_button.setToolTip("Browse for Checkpoint Directory")
        self.browse_ckpt_button.clicked.connect(self.browse_checkpoint_path)
        self.browse_ckpt_button.setMaximumWidth(30)
        ckpt_layout.addWidget(self.browse_ckpt_button)
        left_panel.addLayout(ckpt_layout)

        # Load Models Button
        self.load_models_button = QPushButton("Load Models")
        self.load_models_button.setToolTip("Load the AI models into memory (requires checkpoint path).")
        self.load_models_button.clicked.connect(self.load_models_action)
        left_panel.addWidget(self.load_models_button)

        # Run Buttons
        self.run_button = QPushButton("Generate Captions for All Images in Directory")
        self.run_button.setToolTip("Process all loaded images from the selected directory.")
        self.run_button.clicked.connect(self.generate_captions_action)
        left_panel.addWidget(self.run_button)

        self.caption_selected_button = QPushButton("Caption Selected Image from List")
        self.caption_selected_button.setToolTip("Process the image currently highlighted in the list.")
        self.caption_selected_button.clicked.connect(self.caption_selected_image_action)
        left_panel.addWidget(self.caption_selected_button)

        self.caption_single_button = QPushButton("Caption Single Loaded Image")
        self.caption_single_button.setToolTip("Process the image selected via 'Select Single Image'.")
        self.caption_single_button.clicked.connect(self.caption_single_image_action)
        left_panel.addWidget(self.caption_single_button)

        # Theme Toggle Button
        self.toggle_theme_button = QPushButton("Toggle Dark Mode")
        self.toggle_theme_button.setToolTip("Switch between light and dark themes.")
        self.toggle_theme_button.clicked.connect(self.toggle_theme)
        left_panel.addWidget(self.toggle_theme_button)

        left_panel.addStretch(1)

        # --- Right Panel ---
        right_panel = QVBoxLayout()
        right_panel.setSpacing(10)

        # List widget for images
        right_panel.addWidget(QLabel("Images in Directory:"))
        self.image_list_widget = QListWidget()
        self.image_list_widget.setIconSize(self.image_list_widget.iconSize() * 2)
        self.image_list_widget.itemClicked.connect(self.display_selected_image)
        self.image_list_widget.setToolTip("Click an image to view it and enable 'Caption Selected Image'.")
        right_panel.addWidget(self.image_list_widget, 1)

        # Label to display the selected image
        right_panel.addWidget(QLabel("Selected Image Preview:"))
        self.selected_image_label = QLabel("No image selected")
        self.selected_image_label.setAlignment(Qt.AlignCenter)
        self.selected_image_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.selected_image_label.setMinimumSize(300, 300)
        self.selected_image_label.setStyleSheet("border: 1px solid gray;")
        right_panel.addWidget(self.selected_image_label, 3)

        # Generated Caption Area
        right_panel.addWidget(QLabel("Generated/Editable Caption:"))
        self.generated_caption_text = QTextEdit()
        self.generated_caption_text.setReadOnly(False)
        self.generated_caption_text.setPlaceholderText("Generated caption will appear here. You can edit it before saving.")
        self.generated_caption_text.setToolTip("The generated caption appears here. Edit and use 'Save Edited Caption'.")
        right_panel.addWidget(self.generated_caption_text, 1)

        # Saving Options
        self.overwrite_checkbox = QCheckBox("Overwrite existing captions")
        self.overwrite_checkbox.setToolTip("If checked, automatically overwrites existing .txt files without asking.")
        self.append_checkbox = QCheckBox("Append to existing captions")
        self.append_checkbox.setToolTip("If checked, adds the new caption to the end of the existing .txt file.")

        save_options_layout = QHBoxLayout()
        save_options_layout.addWidget(self.overwrite_checkbox)
        save_options_layout.addWidget(self.append_checkbox)
        save_options_layout.addStretch(1)
        right_panel.addLayout(save_options_layout) # Add layout here

        # Append and overwrite are mutually exclusive: checking "append"
        # disables the overwrite checkbox.
        self.append_checkbox.stateChanged.connect(
            lambda state: self.overwrite_checkbox.setEnabled(state == Qt.Unchecked)
        )

        # Save Edited Caption Button
        self.save_caption_button = QPushButton("Save Edited Caption to File")
        self.save_caption_button.setToolTip("Save the text currently in the box above to the corresponding .txt file using the selected options.")
        self.save_caption_button.clicked.connect(self.save_edited_caption_action)
        right_panel.addWidget(self.save_caption_button)

        # --- Main Layout Assembly ---
        self.main_layout.addLayout(left_panel, 2)
        self.main_layout.addLayout(right_panel, 5)

        # --- Status Bar and Progress Bar ---
        self.status_bar = QStatusBar()
        self.setStatusBar(self.status_bar)
        self.progress_bar = QProgressBar()
        self.status_bar.addPermanentWidget(self.progress_bar)
        self.progress_bar.hide()
        self.show_status("Ready.", 5000)
1188
+
1189
+ def browse_checkpoint_path(self):
1190
+ directory = QFileDialog.getExistingDirectory(self, "Select Checkpoint Directory")
1191
+ if directory:
1192
+ self.checkpoint_path_line.setText(directory)
1193
+ self.update_button_states()
1194
+
1195
    def show_status(self, message, timeout=0):
        """Show *message* in the status bar for *timeout* ms (0 = until replaced)."""
        self.status_bar.showMessage(message, timeout)
        # Process pending events so the message is visible even while
        # long-running work blocks the event loop.
        QApplication.processEvents()
1198
+
1199
+ def update_button_states(self):
1200
+ self.load_models_button.setEnabled(bool(self.checkpoint_path_line.text()))
1201
+ models_ready = self.models_loaded
1202
+ dir_selected = self.input_dir is not None and bool(self.image_files)
1203
+ single_img_selected = self.single_image_path is not None
1204
+ list_img_selected = self.selected_image_path is not None
1205
+ caption_present = bool(self.generated_caption_text.toPlainText().strip())
1206
+
1207
+ self.run_button.setEnabled(models_ready and dir_selected)
1208
+ self.caption_selected_button.setEnabled(models_ready and list_img_selected)
1209
+ self.caption_single_button.setEnabled(models_ready and single_img_selected)
1210
+ self.save_caption_button.setEnabled(caption_present and (list_img_selected or single_img_selected))
1211
+
1212
+ def apply_theme(self):
1213
+ dark_stylesheet = """
1214
+ QMainWindow, QWidget { background-color: #2E2E2E; color: #FFFFFF; font-family: Arial, sans-serif; }
1215
+ QPushButton { background-color: #3A3A3A; color: #FFFFFF; border: 1px solid #555555; padding: 5px; min-height: 20px; }
1216
+ QPushButton:hover { background-color: #555555; }
1217
+ QPushButton:disabled { background-color: #454545; color: #888888; }
1218
+ QLabel { color: #FFFFFF; }
1219
+ QLineEdit, QTextEdit, QComboBox { background-color: #3A3A3A; color: #FFFFFF; border: 1px solid #555555; padding: 4px; }
1220
+ QLineEdit:disabled, QTextEdit:disabled, QComboBox:disabled { background-color: #454545; color: #888888; }
1221
+ QListWidget { background-color: #3A3A3A; color: #FFFFFF; border: 1px solid #555555; alternate-background-color: #424242; }
1222
+ QCheckBox { color: #FFFFFF; spacing: 5px; }
1223
+ QCheckBox::indicator { width: 13px; height: 13px; }
1224
+ QStatusBar { color: #FFFFFF; } QStatusBar::item { border: none; }
1225
+ QProgressBar { border: 1px solid #555555; text-align: center; color: #FFFFFF; background-color: #3A3A3A; }
1226
+ QProgressBar::chunk { background-color: #007ADF; width: 10px; margin: 0.5px; }
1227
+ QToolTip { background-color: #464646; color: #FFFFFF; border: 1px solid #555555; padding: 4px; }
1228
+ QTextEdit { placeholderTextColor: gray; } QLineEdit { placeholderTextColor: gray; }
1229
+ """
1230
+ if self.dark_mode: self.setStyleSheet(dark_stylesheet)
1231
+ else: self.setStyleSheet("")
1232
+
1233
+ placeholder_style = "QTextEdit { placeholderTextColor: gray; } QLineEdit { placeholderTextColor: gray; }"
1234
+ current_style = self.styleSheet()
1235
+ if self.dark_mode:
1236
+ if "placeholderTextColor" not in current_style: self.setStyleSheet(current_style + placeholder_style)
1237
+ else: self.setStyleSheet(current_style.replace(placeholder_style, ""))
1238
+
1239
    def toggle_theme(self):
        """Flip between the light and dark themes and restyle the window."""
        self.dark_mode = not self.dark_mode
        self.apply_theme()
1242
+
1243
    def select_input_directory(self):
        """Pick a batch directory, resetting any single/list selection, and
        populate the image list.  Cancelling clears the directory state."""
        directory = QFileDialog.getExistingDirectory(self, "Select Input Directory")
        if directory:
            self.input_dir = Path(directory)
            self.input_dir_label.setText(str(self.input_dir))
            # A directory selection supersedes any previous single-image or
            # list selection.
            self.single_image_path = None; self.single_image_label.setText("No image selected")
            self.selected_image_path = None; self.selected_image_label.setText("No image selected")
            self.generated_caption_text.clear()
            self.load_images()
            self.show_status(f"Selected directory: {self.input_dir.name}", 5000)
        else:
            self.input_dir_label.setText("No directory selected"); self.input_dir = None
            self.image_list_widget.clear(); self.image_files = []
            self.show_status("Directory selection cancelled.", 3000)
        self.update_button_states()
1258
+
1259
    def select_single_image(self):
        """Pick one image file, clearing any directory/list selection, and
        show it in the preview.  Cancelling clears the single-image state."""
        file_filter = "Image Files (*.jpg *.jpeg *.png *.bmp *.gif *.tiff *.webp)"
        file_path, _ = QFileDialog.getOpenFileName(self, "Select Single Image", "", file_filter)
        if file_path:
            self.single_image_path = Path(file_path)
            self.single_image_label.setText(str(self.single_image_path.name))
            # A single-image selection supersedes any directory selection.
            self.input_dir = None; self.input_dir_label.setText("No directory selected")
            self.image_list_widget.clear(); self.image_files = []
            self.selected_image_path = None
            self.display_image(self.single_image_path)
            self.show_status(f"Selected single image: {self.single_image_path.name}", 5000)
        else:
            self.single_image_label.setText("No image selected"); self.single_image_path = None
            self.show_status("Single image selection cancelled.", 3000)
        self.update_button_states()
1274
+
1275
    def load_images(self):
        """Scan self.input_dir for supported images and fill the list widget
        with thumbnail entries.  Thumbnail failures are non-fatal."""
        if not self.input_dir: return
        self.show_status(f"Loading images from {self.input_dir.name}...")
        image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"]
        try:
            self.image_files = sorted([f for f in self.input_dir.iterdir() if f.is_file() and f.suffix.lower() in image_extensions])
        except Exception as e:
            QMessageBox.critical(self, "Directory Error", f"Could not read directory contents:\n{e}")
            self.show_status(f"Error reading directory {self.input_dir.name}", 5000)
            # Reset directory state so downstream actions see a clean slate.
            self.image_files = []; self.input_dir = None; self.input_dir_label.setText("Error reading directory")

        self.image_list_widget.clear()
        if not self.image_files:
            # Only warn if the directory itself was readable.
            if self.input_dir:
                QMessageBox.warning(self, "No Images", "No supported image files found.")
                self.show_status("No images found in directory.", 3000)
            self.update_button_states()
            return

        thumb_size = 100
        for image_path in self.image_files:
            item = QListWidgetItem(str(image_path.name))
            try:
                pixmap = QPixmap(str(image_path))
                if not pixmap.isNull():
                    scaled_pixmap = pixmap.scaled(thumb_size, thumb_size, Qt.KeepAspectRatio, Qt.SmoothTransformation)
                    item.setIcon(QIcon(scaled_pixmap))
                else: print(f"Warning: QPixmap is null for {image_path.name}")
            except Exception as e: print(f"Warning: Could not create thumbnail for {image_path.name}: {e}")
            self.image_list_widget.addItem(item)

        self.show_status(f"Loaded {len(self.image_files)} images.", 5000)
        self.update_button_states()
1308
+
1309
+ def display_selected_image(self, item):
1310
+ if not self.input_dir or not item: return
1311
+ try:
1312
+ image_name = item.text()
1313
+ image_path = self.input_dir / image_name
1314
+ if not image_path.exists():
1315
+ QMessageBox.warning(self, "File Not Found", f"Image file '{image_name}' no longer exists.")
1316
+ self.selected_image_label.setText("File not found")
1317
+ self.selected_image_label.setPixmap(QPixmap())
1318
+ self.generated_caption_text.clear()
1319
+ self.selected_image_path = None
1320
+ return
1321
+
1322
+ self.selected_image_path = image_path
1323
+ self.single_image_path = None
1324
+ self.single_image_label.setText("No image selected")
1325
+ self.display_image(image_path)
1326
+ caption_file_path = image_path.with_suffix('.txt')
1327
+ if caption_file_path.exists():
1328
+ try:
1329
+ with open(caption_file_path, 'r', encoding='utf-8') as f:
1330
+ caption_content = f.read()
1331
+ self.generated_caption_text.setText(caption_content)
1332
+ status_message = f"Displayed {image_name} and loaded existing caption."
1333
+ except Exception as e:
1334
+ print(f"Warning: Could not read caption file {caption_file_path.name}: {e}")
1335
+ # Keep caption box clear or show error placeholder
1336
+ self.generated_caption_text.setPlaceholderText(f"Error reading caption file for {image_name}.")
1337
+ status_message = f"Displayed {image_name}, but failed to load caption file."
1338
+ else:
1339
+ # Keep caption box clear (already done by display_image)
1340
+ self.generated_caption_text.setPlaceholderText("Generate or edit caption here.")
1341
+ status_message = f"Displayed {image_name}. No existing caption found."
1342
+ self.show_status(f"Selected {image_name} from list.", 4000)
1343
+ except Exception as e:
1344
+ self.selected_image_label.setText("Error loading preview")
1345
+ self.selected_image_path = None
1346
+ QMessageBox.warning(self, "Preview Error", f"Could not load preview for {item.text()}: {e}")
1347
+ self.show_status(f"Error loading preview for {item.text()}", 4000)
1348
+ self.update_button_states()
1349
+
1350
    def display_image(self, image_path):
        """Load *image_path* into the preview label, clearing the caption box.

        On load failure the preview shows an error message instead; errors
        are reported via the status bar and console, never raised.
        """
        try:
            pixmap = QPixmap(str(image_path))
            if not pixmap.isNull():
                self.scale_and_set_pixmap(pixmap)
                # A fresh preview starts with an empty caption editor; any
                # existing sidecar caption is loaded by the caller.
                self.generated_caption_text.clear()
            else:
                self.selected_image_label.setText(f"Cannot display image:\n{image_path.name}")
                self.selected_image_label.setPixmap(QPixmap())
        except Exception as e:
            self.selected_image_label.setText(f"Error loading preview:\n{image_path.name}")
            self.selected_image_label.setPixmap(QPixmap())
            print(f"Error displaying image {image_path}: {e}")
            self.show_status(f"Error displaying image {image_path.name}", 4000)
        self.update_button_states()
1365
+
1366
+ def scale_and_set_pixmap(self, pixmap):
1367
+ if not pixmap or pixmap.isNull():
1368
+ self.selected_image_label.clear()
1369
+ self.selected_image_label.setText("No image selected")
1370
+ return
1371
+ label_size = self.selected_image_label.contentsRect().size()
1372
+ scaled_pixmap = pixmap.scaled(label_size * self.devicePixelRatioF(), Qt.KeepAspectRatio, Qt.SmoothTransformation)
1373
+ self.selected_image_label.setPixmap(scaled_pixmap)
1374
+
1375
+ def load_models_action(self):
1376
+ checkpoint_path_str = self.checkpoint_path_line.text()
1377
+ if not checkpoint_path_str: QMessageBox.warning(self, "Checkpoint Error", "Please specify the checkpoint path."); return
1378
+ checkpoint_path = Path(checkpoint_path_str)
1379
+ if not checkpoint_path.exists() or not checkpoint_path.is_dir():
1380
+ QMessageBox.warning(self, "Checkpoint Error", f"Checkpoint path does not exist or is not a directory:\n{checkpoint_path}"); return
1381
+
1382
+ self.show_status("Loading models... This might take a while.", 0)
1383
+ self.progress_bar.setRange(0, 0); self.progress_bar.show(); QApplication.processEvents()
1384
+ try:
1385
+ (self.clip_processor, self.clip_model, self.tokenizer, self.text_model, self.image_adapter) = load_models(checkpoint_path, status_callback=self.show_status)
1386
+ self.models_loaded = True
1387
+ QMessageBox.information(self, "Models Loaded", "Models have been loaded successfully.")
1388
+ self.show_status("Models loaded successfully. Ready to caption.", 5000)
1389
+ except Exception as e:
1390
+ self.models_loaded = False
1391
+ QMessageBox.critical(self, "Model Loading Error", f"An error occurred while loading models:\n{e}\n\nCheck console for details.")
1392
+ self.show_status(f"Model loading failed. Check console.", 0)
1393
+ print(f"--- Model Loading Error ---"); import traceback; traceback.print_exc(); print(f"--- End Error Traceback ---")
1394
+ finally:
1395
+ self.progress_bar.hide(); self.progress_bar.setRange(0, 100); self.update_button_states()
1396
+
1397
+ def collect_parameters(self):
1398
+ return (self.caption_type_combo.currentText(), self.caption_length_combo.currentText(),
1399
+ [cb.text() for cb in self.extra_options_checkboxes if cb.isChecked()],
1400
+ self.name_input_line.text(), self.custom_prompt_text.toPlainText())
1401
+
1402
+ def _confirm_overwrite(self, file_path: Path) -> bool:
1403
+ if file_path.exists():
1404
+ reply = QMessageBox.question(self, 'Confirm Overwrite', f"Caption file '{file_path.name}' already exists.\nOverwrite?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
1405
+ return reply == QMessageBox.Yes
1406
+ return True
1407
+
1408
+ def _save_caption_to_file(self, image_path: Path, caption: str) -> bool:
1409
+ if not image_path: self.show_status("Error: No image path associated.", 5000); return False
1410
+ caption_file = image_path.with_suffix('.txt')
1411
+ mode = 'a' if self.append_checkbox.isChecked() else 'w'
1412
+ prefix = '\n' if mode == 'a' and caption_file.exists() and caption_file.stat().st_size > 0 else ''
1413
+
1414
+ if mode == 'w' and caption_file.exists() and not self.overwrite_checkbox.isChecked():
1415
+ if not self._confirm_overwrite(caption_file):
1416
+ self.show_status(f"Skipped saving {image_path.name}.", 3000); return False
1417
+ try:
1418
+ with open(caption_file, mode, encoding='utf-8') as f: f.write(f"{prefix}{caption}")
1419
+ self.show_status(f"Caption {'appended to' if mode == 'a' else 'saved to'} {caption_file.name}", 4000); return True
1420
+ except Exception as e:
1421
+ QMessageBox.critical(self, "Save Error", f"Error saving caption for {image_path.name}:\n{e}")
1422
+ self.show_status(f"Error saving caption for {image_path.name}", 5000); print(f"Error saving caption to {caption_file}: {e}"); return False
1423
+
1424
+ def _run_caption_generation(self, image_path: Path):
1425
+ if not self.models_loaded: QMessageBox.warning(self, "Models Not Loaded", "Please load models first."); return None
1426
+ if not image_path or not image_path.exists():
1427
+ QMessageBox.warning(self, "Image Not Found", f"Image file does not exist:\n{image_path}")
1428
+ self.show_status(f"Image not found: {image_path.name if image_path else 'None'}", 5000); return None
1429
+
1430
+ self.show_status(f"Processing: {image_path.name}...", 0); QApplication.processEvents()
1431
+ params = self.collect_parameters()
1432
+ try: input_image = Image.open(image_path)
1433
+ except Exception as e:
1434
+ QMessageBox.critical(self, "Image Open Error", f"Failed to open {image_path.name}:\n{e}")
1435
+ self.show_status(f"Error opening {image_path.name}", 5000); print(f"Error opening image {image_path}: {e}"); return None
1436
+ try:
1437
+ prompt_str, caption = generate_caption(input_image, *params, self.clip_model, self.tokenizer, self.text_model, self.image_adapter)
1438
+ current_viewed_path = self.selected_image_path or self.single_image_path
1439
+ if image_path == current_viewed_path: self.generated_caption_text.setText(caption)
1440
+ if self._save_caption_to_file(image_path, caption): print(f"Caption generated and saved for {image_path.name}")
1441
+ else: print(f"Caption generated but NOT saved for {image_path.name}")
1442
+ return caption
1443
+ except Exception as e:
1444
+ QMessageBox.critical(self, "Processing Error", f"Failed to process {image_path.name}:\n{e}\n\nCheck console.")
1445
+ self.show_status(f"Error processing {image_path.name}. Check console.", 0)
1446
+ print(f"--- Processing Error for {image_path.name} ---"); import traceback; traceback.print_exc(); print(f"--- End Error Traceback ---")
1447
+ current_viewed_path = self.selected_image_path or self.single_image_path
1448
+ if image_path == current_viewed_path: self.generated_caption_text.setText(f"Error generating caption. See console.")
1449
+ return None
1450
+ finally: QApplication.processEvents()
1451
+
1452
    def generate_captions_action(self):
        """Batch-caption every image in the selected input directory.

        Asks a one-time "overwrite ALL?" question up front when existing
        caption files are found, drives the progress bar per image, and
        reports an approximate summary at the end (skips cannot be counted
        reliably — see inline notes).
        """
        if not self.input_dir or not self.image_files: QMessageBox.warning(self, "No Images", "Select directory with images first."); return
        if not self.models_loaded: QMessageBox.warning(self, "Models Not Loaded", "Load models first."); return

        num_images = len(self.image_files)
        self.progress_bar.setRange(0, num_images); self.progress_bar.setValue(0); self.progress_bar.show()
        self.show_status(f"Starting batch captioning for {num_images} images...", 0)

        # NOTE(review): skipped_explicitly is assigned but never incremented or
        # reported — skips are not tracked (see comments below).
        processed_count, error_count, skipped_explicitly = 0, 0, 0
        original_overwrite_state = self.overwrite_checkbox.isChecked() # Remember original state
        ask_all = False # Flag to check if user agreed to overwrite all

        # Pre-check for overwrites if needed: collect existing .txt siblings only
        # when neither "overwrite" nor "append" is already enabled.
        files_to_confirm = []
        if not self.overwrite_checkbox.isChecked() and not self.append_checkbox.isChecked():
            files_to_confirm = [img.with_suffix('.txt').name for img in self.image_files if img.with_suffix('.txt').exists()]

        if files_to_confirm:
            reply = QMessageBox.question(self, 'Confirm Overwrite Multiple', f"{len(files_to_confirm)} existing caption file(s) found.\nOverwrite ALL existing files?", QMessageBox.Yes | QMessageBox.No | QMessageBox.Cancel, QMessageBox.Cancel)
            if reply == QMessageBox.Cancel: self.show_status("Batch cancelled.", 3000); self.progress_bar.hide(); return
            # "No" falls through: the per-file confirmation dialog still fires
            # inside _save_caption_to_file for each existing caption.
            elif reply == QMessageBox.Yes: ask_all = True; self.overwrite_checkbox.setChecked(True) # Temporarily check it

        # Process images
        for i, image_path in enumerate(self.image_files):
            # Run generation. _save_caption_to_file handles individual confirmation if ask_all is False
            caption_result = self._run_caption_generation(image_path)

            # Track results (Approximate - relies on _save reporting skips)
            if caption_result is not None:
                processed_count += 1
            else:
                # If None, assume error unless status bar indicates skip (imperfect)
                # HACK: sniffing the status-bar text is fragile; a later status
                # message would misclassify a skip as an error.
                if "Skipped saving" not in self.status_bar.currentMessage():
                    error_count += 1
                # No reliable way to count skips here without modifying _save return value

            self.progress_bar.setValue(i + 1)
            QApplication.processEvents()

        # Restore overwrite checkbox state if changed
        if ask_all: self.overwrite_checkbox.setChecked(original_overwrite_state)

        self.progress_bar.hide()
        final_message = f"Batch finished. {processed_count} captions generated/saved."
        if error_count > 0: final_message += f" {error_count} errors."
        # Cannot reliably report skips here
        QMessageBox.information(self, "Batch Complete", final_message)
        self.show_status(final_message, 10000)
        self.update_button_states()
1501
+
1502
+ def caption_selected_image_action(self):
1503
+ if not self.selected_image_path: QMessageBox.warning(self, "No Image Selected", "Select image from list first."); return
1504
+ self._run_caption_generation(self.selected_image_path); self.update_button_states()
1505
+
1506
+ def caption_single_image_action(self):
1507
+ if not self.single_image_path: QMessageBox.warning(self, "No Image Selected", "Select single image first."); return
1508
+ self._run_caption_generation(self.single_image_path); self.update_button_states()
1509
+
1510
+ def save_edited_caption_action(self):
1511
+ edited_caption = self.generated_caption_text.toPlainText().strip()
1512
+ if not edited_caption: QMessageBox.warning(self, "Empty Caption", "Caption text is empty."); return
1513
+ current_image_path = self.selected_image_path or self.single_image_path
1514
+ if not current_image_path: QMessageBox.warning(self, "No Associated Image", "Select image first."); return
1515
+ self._save_caption_to_file(current_image_path, edited_caption)
1516
+
1517
+ def resizeEvent(self, event):
1518
+ super().resizeEvent(event)
1519
+ current_path = None
1520
+ if self.selected_image_label.pixmap() and not self.selected_image_label.pixmap().isNull():
1521
+ current_path = self.selected_image_path or self.single_image_path
1522
+ if current_path and current_path.exists():
1523
+ try:
1524
+ pixmap = QPixmap(str(current_path))
1525
+ if not pixmap.isNull(): self.scale_and_set_pixmap(pixmap)
1526
+ except Exception as e: print(f"Error reloading pixmap on resize for {current_path}: {e}")
1527
+ elif not self.selected_image_label.text() or self.selected_image_label.text().startswith(("Cannot", "Error", "No image")):
1528
+ self.selected_image_label.clear(); self.selected_image_label.setText("No image selected")
1529
+
1530
+
1531
if __name__ == "__main__":
    # High-DPI attributes must be set before the QApplication is constructed.
    QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True)  # Optional
    QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True)  # Optional
    app = QApplication(sys.argv)
    app.setStyle("Fusion")  # Optional
    main_window = CaptionApp()
    main_window.show()
    sys.exit(app.exec_())