dominic1021 commited on
Commit
eb470d1
·
verified ·
1 Parent(s): 1aff144

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +440 -0
app.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import spaces
4
+ import gradio as gr
5
+ from huggingface_hub import InferenceClient
6
+ from torch import nn
7
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM, BitsAndBytesConfig
8
+ import torch
9
+ import torch.amp.autocast_mode
10
+ from PIL import Image
11
+ import torchvision.transforms.functional as TVF
12
+ import gc
13
+ from peft import PeftConfig
14
+ from gradio.themes.utils import colors
15
+ from gradio.themes import Base
16
+
17
# Define the base directory
BASE_DIR = Path(__file__).resolve().parent

# Single device string used for all model placement below.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hub token for gated/private repos (e.g. meta-llama); None -> anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Toggled by the UI; intended to route generation through InferenceClient
# instead of a local model (the client path is not exercised in this file).
use_inference_client = False

# Selectable LLM repos: repo_id -> gguf filename (None = plain transformers checkpoint).
# change_text_model() appends user-entered repos at runtime.
llm_models = {
    "bunnycore/LLama-3.1-8B-Matrix": None,
    "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
    "DevQuasar/HermesNova-Llama-3.1-8B": None,
    "mergekit-community/L3.1-Boshima-b-FIX": None,
    "meta-llama/Meta-Llama-3.1-8B": None, # gated
}

CLIP_PATH = "google/siglip-so400m-patch14-384"
# Default LLM is the first entry of llm_models.
MODEL_PATH = list(llm_models.keys())[0]
# Local checkpoint folder holding the trained adapters (image adapter, LoRA, CLIP weights).
CHECKPOINT_PATH = BASE_DIR / "9em124t2-499968"
LORA_PATH = CHECKPOINT_PATH / "text_model"

JC_TITLE_MD = "<h1><center>JoyCaption Alpha One Mod</center></h1>"
JC_DESC_MD = """This space is mod of [fancyfeast/joy-caption-alpha-one](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-one),
[Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha)"""

# Prompt templates keyed by (caption_type, caption_tone, length_is_word, length_is_count).
# stream_chat_mod() builds the same 4-tuple from its arguments to select a template.
CAPTION_TYPE_MAP = {
    ("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
    ("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
    ("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
    ("descriptive", "informal", False, False): ["Write a descriptive caption for this image in a casual tone."],
    ("descriptive", "informal", False, True): ["Write a descriptive caption for this image in a casual tone within {word_count} words."],
    ("descriptive", "informal", True, False): ["Write a {length} descriptive caption for this image in a casual tone."],
    ("training_prompt", "formal", False, False): ["Write a stable diffusion prompt for this image."],
    ("training_prompt", "formal", False, True): ["Write a stable diffusion prompt for this image within {word_count} words."],
    ("training_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt for this image."],
    ("rng-tags", "formal", False, False): ["Write a list of Booru tags for this image."],
    ("rng-tags", "formal", False, True): ["Write a list of Booru tags for this image within {word_count} words."],
    ("rng-tags", "formal", True, False): ["Write a {length} list of Booru tags for this image."],
}
56
+
57
class ImageAdapter(nn.Module):
    """Project CLIP vision hidden states into the LLM's embedding space.

    Output layout per batch item: [learned token 0] + projected image tokens
    + [learned token 1].  A third learned embedding (index 2) is exposed via
    get_eot_embedding().  Submodule names and creation order are part of the
    checkpoint contract (state_dict keys, RNG stream) and must not change.
    """

    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
        super().__init__()
        self.deep_extract = deep_extract

        if self.deep_extract:
            # Five hidden layers get concatenated along the feature axis.
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        if ln1:
            self.ln1 = nn.LayerNorm(input_features)
        else:
            self.ln1 = nn.Identity()
        if pos_emb:
            self.pos_emb = nn.Parameter(torch.zeros(num_image_tokens, input_features))
        else:
            self.pos_emb = None

        # Three learned auxiliary embeddings: 0/1 wrap the image sequence,
        # 2 serves as an end-of-text embedding (see get_eot_embedding).
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        if self.deep_extract:
            # Fuse several depths of the vision tower, feature-wise.
            picked_layers = (
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            )
            features = torch.concat(picked_layers, dim=-1)
            assert len(features.shape) == 3, f"Expected 3, got {len(features.shape)}"
            expected_width = vision_outputs[-2].shape[-1] * 5
            assert features.shape[-1] == expected_width, f"Expected {expected_width}, got {features.shape[-1]}"
        else:
            features = vision_outputs[-2]

        features = self.ln1(features)

        if self.pos_emb is not None:
            assert features.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {features.shape[-2:]}"
            features = features + self.pos_emb

        features = self.linear2(self.activation(self.linear1(features)))

        batch = features.shape[0]
        wrapper_ids = torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(batch, -1)
        wrappers = self.other_tokens(wrapper_ids)
        assert wrappers.shape == (batch, 2, features.shape[2]), f"Expected {(batch, 2, features.shape[2])}, got {wrappers.shape}"

        return torch.cat((wrappers[:, 0:1], features, wrappers[:, 1:2]), dim=1)

    def get_eot_embedding(self):
        """Return the learned embedding at index 2, shape (output_features,)."""
        eot_index = torch.tensor([2], device=self.other_tokens.weight.device)
        return self.other_tokens(eot_index).squeeze(0)
106
+
107
# Lazily-populated globals, (re)assigned by load_text_model() / change_text_model()
# and read by stream_chat_mod().
tokenizer = None          # active tokenizer for the selected LLM
text_model_client = None  # InferenceClient slot (unused in the local path)
text_model = None         # active causal LM
image_adapter = None      # trained ImageAdapter bridging CLIP -> LLM
peft_config = None        # LoRA config attached to text_model, when present
113
def load_text_model(model_name: str=MODEL_PATH, gguf_file: str | None=None, is_nf4: bool=True):
    """Load tokenizer, causal LM, optional LoRA adapter and the image adapter.

    Mutates the module globals ``tokenizer``, ``text_model``, ``image_adapter``
    and ``peft_config``.  Raises a wrapped Exception on any load failure; always
    clears CUDA cache / runs gc on exit.

    model_name: HF repo id (or local path) of the LLM.
    gguf_file: optional GGUF filename inside the repo; when set, the fast
        tokenizer is used and the GGUF weights are dequantized on load.
    is_nf4: apply bitsandbytes NF4 4-bit quantization (GPU only).
    """
    global tokenizer, text_model, image_adapter, peft_config, text_model_client, use_inference_client
    try:
        nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                        bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
        print("Loading tokenizer")
        if gguf_file:
            tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=False)
        assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

        print(f"Loading LLM: {model_name}")
        # Build the from_pretrained kwargs once instead of six near-identical branches.
        load_kwargs = {"device_map": device, "torch_dtype": torch.bfloat16}
        if gguf_file:
            # BUGFIX: the original dropped gguf_file on the CUDA branches, silently
            # loading the base repo weights instead of the selected GGUF file.
            load_kwargs["gguf_file"] = gguf_file
        if device != "cpu" and is_nf4:
            load_kwargs["quantization_config"] = nf4_config
        text_model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs).eval()

        if LORA_PATH.exists():
            print("Loading VLM's custom text model")
            if is_nf4:
                peft_config = PeftConfig.from_pretrained(str(LORA_PATH), device_map=device, quantization_config=nf4_config)
            else:
                peft_config = PeftConfig.from_pretrained(str(LORA_PATH), device_map=device)
            text_model.add_adapter(peft_config)
            text_model.enable_adapters()

        print("Loading image adapter")
        # Sized to bridge CLIP hidden size -> LLM hidden size; weights come from the checkpoint.
        image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
        image_adapter_path = CHECKPOINT_PATH / "image_adapter.pt"
        image_adapter.load_state_dict(torch.load(image_adapter_path, map_location="cpu", weights_only=True))
        image_adapter.eval().to(device)
    except Exception as e:
        print(f"LLM load error: {e}")
        raise Exception(f"LLM load error: {e}") from e
    finally:
        torch.cuda.empty_cache()
        gc.collect()

# Hint for the HF Spaces ZeroGPU runtime.
load_text_model.zerogpu = True
163
+
164
# Load CLIP (vision tower only; the processor is kept for completeness).
print("Loading CLIP")
clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model

# Overlay the fine-tuned vision weights when the checkpoint ships them.
clip_model_path = CHECKPOINT_PATH / "clip_model.pt"
if clip_model_path.exists():
    print("Loading VLM's custom vision model")
    checkpoint = torch.load(clip_model_path, map_location='cpu')
    # Strip the torch.compile/DDP-style key prefix so keys match the bare vision model.
    checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
    clip_model.load_state_dict(checkpoint)
    del checkpoint

clip_model.eval().requires_grad_(False).to(device)

# Load text model (populates tokenizer/text_model/image_adapter globals).
load_text_model()
181
+
182
@spaces.GPU()
@torch.no_grad()
def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, progress=gr.Progress(track_tqdm=True)) -> str:
    """Generate a caption for *input_image* using the loaded CLIP + adapter + LLM.

    caption_type: "descriptive" | "training_prompt" | "rng-tags".
    caption_tone: "formal" | "informal"; forced to "formal" for non-descriptive types.
    caption_length: "any", a length word (e.g. "short"), or a word count (int / numeric str).
    Returns the decoded caption, whitespace-stripped.
    Raises ValueError when the resolved key is not in CAPTION_TYPE_MAP.
    """
    global use_inference_client
    global text_model
    torch.cuda.empty_cache()
    gc.collect()

    # "any" -> no length constraint; numeric strings become int word counts.
    length = None if caption_length == "any" else caption_length

    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass

    # Tone only exists for descriptive captions (see CAPTION_TYPE_MAP keys).
    if caption_type == "rng-tags" or caption_type == "training_prompt":
        caption_tone = "formal"

    prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
    if prompt_key not in CAPTION_TYPE_MAP:
        raise ValueError(f"Invalid caption type: {prompt_key}")

    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
    print(f"Prompt: {prompt_str}")

    # Preprocess: resize to the SigLIP 384x384 input, scale to [0,1], normalize to [-1,1].
    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to(device)

    # Tokenize the prompt (special tokens are spliced in manually below).
    prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)

    # Embed the image through CLIP hidden states + the trained image adapter.
    with torch.amp.autocast_mode.autocast(device, enabled=True):
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        image_features = vision_outputs.hidden_states
        embedded_images = image_adapter(image_features)
        embedded_images = embedded_images.to(device)

    # Build the embedding sequence: <BOS> [image tokens] [prompt] [EOT embedding].
    prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)

    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images.to(dtype=embedded_bos.dtype),
        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
        eot_embed.expand(embedded_images.shape[0], -1, -1),
    ], dim=1)

    # Parallel token ids; image positions are dummy zeros since their
    # embeddings are supplied directly via inputs_embeds.
    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        prompt,
        torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
    ], dim=1).to(device)
    attention_mask = torch.ones_like(input_ids)

    text_model.to(device)
    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
                                       do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)

    # Keep only the newly generated tokens; drop one trailing EOS/<|eot_id|> if present.
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

    return caption.strip()
252
+
253
def is_repo_name(s):
    """Return a truthy match object when *s* looks like "owner/name", else None."""
    import re
    # Two non-empty segments without slashes, commas, whitespace or quotes.
    repo_pattern = re.compile(r'^[^/,\s\"\']+/[^/,\s\"\']+$')
    return repo_pattern.fullmatch(s)
256
+
257
def is_repo_exists(repo_id):
    """Check whether *repo_id* exists on the Hub; assume it does on API failure."""
    from huggingface_hub import HfApi
    try:
        return HfApi(token=HF_TOKEN).repo_exists(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to connect {repo_id}.")
        print(e)
        return True # for safety
266
+
267
def get_text_model():
    """Return the repo ids currently registered in ``llm_models``."""
    return [repo_id for repo_id in llm_models]
269
+
270
def is_gguf_repo(repo_id: str):
    """Return True when *repo_id* exists and contains at least one .gguf file."""
    from huggingface_hub import HfApi
    try:
        api = HfApi(token=HF_TOKEN)
        if not is_repo_name(repo_id) or not is_repo_exists(repo_id):
            return False
        files = api.list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        gr.Warning(f"Error: Failed to get {repo_id}'s info.")
        return False
    return any(name.endswith(".gguf") for name in files)
284
+
285
def get_repo_gguf(repo_id: str):
    """Return a Dropdown update listing the repo's .gguf files (first one preselected)."""
    from huggingface_hub import HfApi
    try:
        api = HfApi(token=HF_TOKEN)
        if not is_repo_name(repo_id) or not is_repo_exists(repo_id):
            return gr.update(value="", choices=[])
        files = api.list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        gr.Warning(f"Error: Failed to get {repo_id}'s info.")
        return gr.update(value="", choices=[])
    gguf_files = [name for name in files if name.endswith(".gguf")]
    if not gguf_files:
        return gr.update(value="", choices=[])
    return gr.update(value=gguf_files[0], choices=gguf_files)
302
+
303
@spaces.GPU()
def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_file: str | None=None,
                      is_nf4: bool=True, progress=gr.Progress(track_tqdm=True)):
    """Switch the active LLM and refresh the model dropdown.

    Validates the repo, asks for a gguf file when the repo only ships GGUF
    weights, then reloads the model via load_text_model() and registers the
    repo in ``llm_models``.  Returns a gr.update for the dropdown (or makes
    the gguf dropdown visible when a file must be chosen first).
    Raises gr.Error on a missing repo or a failed load.
    """
    global use_inference_client, llm_models
    use_inference_client = use_client
    # BUGFIX: validation used to live inside the try, so these gr.Errors were
    # caught by the broad except and re-wrapped as "Model load error: ...".
    if not is_repo_name(model_name) or not is_repo_exists(model_name):
        raise gr.Error(f"Repo doesn't exist: {model_name}")
    if not gguf_file and is_gguf_repo(model_name):
        gr.Info("Please select a gguf file.")  # plain string; was a placeholder-less f-string
        return gr.update(visible=True)
    try:
        if not use_inference_client:
            load_text_model(model_name, gguf_file, is_nf4)
        # Remember user-entered repos so they stay selectable in the dropdown.
        if model_name not in llm_models:
            llm_models[model_name] = gguf_file if gguf_file else None
        return gr.update(choices=get_text_model())
    except Exception as e:
        raise gr.Error(f"Model load error: {model_name}, {e}") from e
321
+
322
# Define a custom theme
class NeonPurpleTheme(Base):
    """Gradio theme: purple gradient page, violet controls, gold accents."""

    def __init__(self):
        super().__init__(
            primary_hue="purple",
            secondary_hue="violet",
            neutral_hue="slate",
            font=("Roboto", "sans-serif"),
        )
        # Override individual CSS variables on top of the hue presets.
        self.set(
            body_background_fill="linear-gradient(to right, #2E0854, #5B0E91)",
            body_background_fill_dark="linear-gradient(to right, #2E0854, #5B0E91)",
            body_text_color="#FFFFFF",
            body_text_color_dark="#FFFFFF",
            button_primary_background_fill="#8A2BE2",
            button_primary_background_fill_hover="#9B30FF",
            button_primary_text_color="#FFFFFF",
            button_secondary_background_fill="#4B0082",
            button_secondary_background_fill_hover="#5D478B",
            button_secondary_text_color="#FFFFFF",
            background_fill_primary="#3C1361",
            background_fill_secondary="#4B0082",
            border_color_primary="#8A2BE2",
            block_title_text_color="#FFD700",
            block_label_text_color="#E6E6FA",
            input_background_fill="#2F0147",
            input_border_color="#8A2BE2",
            input_placeholder_color="#B19CD9",
            slider_color="#8A2BE2",
            slider_thumb_color="#FFD700",
            checkbox_background_color="#4B0082",
            checkbox_border_color="#8A2BE2",
            checkbox_check_color="#FFD700",
        )
356
+
357
# Update the CSS
# Extra page styling layered on top of NeonPurpleTheme: centered info text,
# capped container width, rounded/shadowed cards.
css = """
.info {
    text-align: center !important;
    color: #E6E6FA;
    font-size: 1.1em;
    margin-top: 20px;
}
.gradio-container {
    max-width: 1200px !important;
    margin: auto;
}
.gr-button {
    font-weight: bold;
}
.gr-form {
    border-radius: 15px;
    padding: 20px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.gr-box {
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.gr-padded {
    padding: 20px;
}
"""
385
+
386
# Update the Gradio interface
with gr.Blocks(theme=NeonPurpleTheme(), css=css) as demo:
    gr.HTML(f"<h1 style='text-align: center; color: #FFD700;'>JoyCaption Alpha One Mod</h1>")
    with gr.Row():
        # Left column: inputs and generation settings.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="gr-form"):
                jc_input_image = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"], height=384)
                with gr.Row():
                    jc_caption_type = gr.Dropdown(
                        choices=["descriptive", "training_prompt", "rng-tags"],
                        label="Caption Type",
                        value="descriptive",
                    )
                    jc_caption_tone = gr.Dropdown(
                        choices=["formal", "informal"],
                        label="Caption Tone",
                        value="formal",
                    )
                    jc_caption_length = gr.Dropdown(
                        choices=["any", "very short", "short", "medium-length", "long", "very long"] +
                                [str(i) for i in range(20, 261, 10)],
                        label="Caption Length",
                        value="any",
                    )
                gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.", elem_classes="info")
                # Model selection + sampling parameters.
                with gr.Accordion("Advanced", open=False):
                    with gr.Row():
                        jc_text_model = gr.Dropdown(label="LLM Model", info="You can enter a huggingface model repo_id to want to use.",
                                                    choices=get_text_model(), value=get_text_model()[0],
                                                    allow_custom_value=True, interactive=True, min_width=320)
                        # Hidden until a GGUF-only repo is selected (see change_text_model).
                        jc_gguf = gr.Dropdown(label=f"GGUF Filename", choices=[], value="",
                                              allow_custom_value=True, min_width=320, visible=False)
                        jc_nf4 = gr.Checkbox(label="Use NF4 quantization", value=True)
                        jc_text_model_button = gr.Button("Load Model", variant="secondary")
                        jc_use_inference_client = gr.Checkbox(label="Use Inference Client", value=False, visible=False)
                    with gr.Row():
                        jc_tokens = gr.Slider(minimum=1, maximum=4096, value=300, step=1, label="Max tokens")
                        jc_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature")
                        jc_topp = gr.Slider(minimum=0, maximum=2.0, value=0.9, step=0.01, label="Top-P")
                jc_run_button = gr.Button("Generate Caption", variant="primary", size="lg")

        # Right column: output.
        with gr.Column(scale=1):
            jc_output_caption = gr.Textbox(label="Generated Caption", show_copy_button=True, elem_classes="gr-box gr-padded")

    gr.Markdown(JC_DESC_MD, elem_classes="info")
    with gr.Row():
        gr.LoginButton()
        gr.DuplicateButton(value="Duplicate Space for private use", variant="secondary")

    # Input order matches stream_chat_mod's signature: image, type, tone, length,
    # max_new_tokens (jc_tokens), top_p (jc_topp), temperature.
    jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_caption_type, jc_caption_tone, jc_caption_length, jc_tokens, jc_topp, jc_temperature], outputs=[jc_output_caption])
    jc_text_model_button.click(change_text_model, inputs=[jc_text_model, jc_use_inference_client, jc_gguf, jc_nf4], outputs=[jc_text_model])
    jc_use_inference_client.change(change_text_model, inputs=[jc_text_model, jc_use_inference_client], outputs=[jc_text_model])

if __name__ == "__main__":
    demo.launch(share=True)