import os
from pathlib import Path
import spaces
import gradio as gr
from huggingface_hub import InferenceClient
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import torch.amp.autocast_mode
from PIL import Image
import torchvision.transforms.functional as TVF
import gc
from peft import PeftConfig

BASE_DIR = Path(__file__).resolve().parent

device = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
use_inference_client = False

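# Selectable LLMs; each value holds the GGUF filename to load, or None for a plain Transformers repo.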
llm_models = {
    "bunnycore/LLama-3.1-8B-Matrix": None,
    "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
    "DevQuasar/HermesNova-Llama-3.1-8B": None,
    "mergekit-community/L3.1-Boshima-b-FIX": None,
    "meta-llama/Meta-Llama-3.1-8B": None,
}

CLIP_PATH = "google/siglip-so400m-patch14-384"
MODEL_PATH = list(llm_models.keys())[0]
CHECKPOINT_PATH = BASE_DIR / "9em124t2-499968"
LORA_PATH = CHECKPOINT_PATH / "text_model"

JC_TITLE_MD = "<h1><center>JoyCaption Alpha One Mod</center></h1>"
JC_DESC_MD = """This space is a mod of [fancyfeast/joy-caption-alpha-one](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-one) and
[Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha)."""

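# Prompt templates keyed by (caption_type, caption_tone, length-is-a-word, length-is-a-number);
# stream_chat_mod() builds the matching key from the UI selections.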
CAPTION_TYPE_MAP = {
    ("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
    ("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
    ("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
    ("descriptive", "informal", False, False): ["Write a descriptive caption for this image in a casual tone."],
    ("descriptive", "informal", False, True): ["Write a descriptive caption for this image in a casual tone within {word_count} words."],
    ("descriptive", "informal", True, False): ["Write a {length} descriptive caption for this image in a casual tone."],
    ("training_prompt", "formal", False, False): ["Write a stable diffusion prompt for this image."],
    ("training_prompt", "formal", False, True): ["Write a stable diffusion prompt for this image within {word_count} words."],
    ("training_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt for this image."],
    ("rng-tags", "formal", False, False): ["Write a list of Booru tags for this image."],
    ("rng-tags", "formal", False, True): ["Write a list of Booru tags for this image within {word_count} words."],
    ("rng-tags", "formal", True, False): ["Write a {length} list of Booru tags for this image."],
}

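# Projects SigLIP vision hidden states into the LLM's embedding space and brackets the
# resulting image tokens with learned start/end embeddings (plus a learned EOT embedding).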
class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
        super().__init__()
        self.deep_extract = deep_extract

        if self.deep_extract:
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))

        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        if self.deep_extract:
            # Concatenate hidden states from several depths for a richer image representation.
            x = torch.concat((
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            ), dim=-1)
            assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
        else:
            x = vision_outputs[-2]

        x = self.ln1(x)

        if self.pos_emb is not None:
            assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
            x = x + self.pos_emb

        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        # Tokens 0 and 1 mark the start and end of the image span; token 2 is the EOT embedding.
        other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
        assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)

        return x

    def get_eot_embedding(self):
        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)

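# Globals populated by load_text_model().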
tokenizer = None
text_model_client = None
text_model = None
image_adapter = None
peft_config = None

def load_text_model(model_name: str=MODEL_PATH, gguf_file: str | None=None, is_nf4: bool=True):
    global tokenizer, text_model, image_adapter, peft_config, text_model_client, use_inference_client
    try:
        nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                        bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
        print("Loading tokenizer")
        if gguf_file:
            tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=False)
        assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

        print(f"Loading LLM: {model_name}")
        # Build from_pretrained kwargs: pass the GGUF file whenever one is given,
        # and apply NF4 quantization only on GPU when requested.
        kwargs = {"device_map": device, "torch_dtype": torch.bfloat16}
        if gguf_file:
            kwargs["gguf_file"] = gguf_file
        if device != "cpu" and is_nf4:
            kwargs["quantization_config"] = nf4_config
        text_model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()

        if LORA_PATH.exists():
            print("Loading VLM's custom text model")
            if is_nf4:
                peft_config = PeftConfig.from_pretrained(str(LORA_PATH), device_map=device, quantization_config=nf4_config)
            else:
                peft_config = PeftConfig.from_pretrained(str(LORA_PATH), device_map=device)
            text_model.add_adapter(peft_config)
            text_model.enable_adapters()

        print("Loading image adapter")
        # clip_model is defined at module level below and exists by the time this runs.
        image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
        image_adapter_path = CHECKPOINT_PATH / "image_adapter.pt"
        image_adapter.load_state_dict(torch.load(image_adapter_path, map_location="cpu", weights_only=True))
        image_adapter.eval().to(device)
    except Exception as e:
        print(f"LLM load error: {e}")
        raise Exception(f"LLM load error: {e}") from e
    finally:
        torch.cuda.empty_cache()
        gc.collect()

load_text_model.zerogpu = True

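# The SigLIP vision tower is loaded once at import time and shared across requests.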
print("Loading CLIP")
clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model

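# The fine-tuned vision weights were saved from a compiled/wrapped model, so the
# "_orig_mod.module." key prefix is stripped before loading.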
clip_model_path = CHECKPOINT_PATH / "clip_model.pt"
if clip_model_path.exists():
    print("Loading VLM's custom vision model")
    checkpoint = torch.load(clip_model_path, map_location='cpu')
    checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
    clip_model.load_state_dict(checkpoint)
    del checkpoint

clip_model.eval().requires_grad_(False).to(device)

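# Load the default LLM, its LoRA (if present), and the image adapter at startup.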
load_text_model()

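# Caption generation: encode the image with SigLIP, project it through the ImageAdapter,
# splice the result between the BOS and prompt embeddings, then generate.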
@spaces.GPU()
@torch.no_grad()
def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, progress=gr.Progress(track_tqdm=True)) -> str:
    global use_inference_client
    global text_model
    torch.cuda.empty_cache()
    gc.collect()

    length = None if caption_length == "any" else caption_length

    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass

    # Tone doesn't apply to rng-tags and training_prompt.
    if caption_type in ("rng-tags", "training_prompt"):
        caption_tone = "formal"

    prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
    if prompt_key not in CAPTION_TYPE_MAP:
        raise ValueError(f"Invalid caption type: {prompt_key}")

    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
    print(f"Prompt: {prompt_str}")

    # Preprocess the image: SigLIP-so400m expects 384x384 inputs normalized to [-1, 1].
    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to(device)

    # Tokenize the prompt
    prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)

    # Embed the image with the vision tower and project it into the LLM's embedding space.
    with torch.amp.autocast_mode.autocast(device, enabled=True):
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        image_features = vision_outputs.hidden_states
        embedded_images = image_adapter(image_features)
        embedded_images = embedded_images.to(device)

    # Embed the prompt and the special tokens that frame the image span.
    prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)

    # Assemble the input sequence: <bos>, image tokens, prompt, <|eot_id|>.
    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images.to(dtype=embedded_bos.dtype),
        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
        eot_embed.expand(embedded_images.shape[0], -1, -1),
    ], dim=1)

    # Matching token ids; the image positions are placeholder zeros because generation
    # runs from inputs_embeds, not input_ids.
    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        prompt,
        torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
    ], dim=1).to(device)
    attention_mask = torch.ones_like(input_ids)

    text_model.to(device)
    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
                                       do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)

    # Trim the prompt and a trailing EOS/<|eot_id|> token, if present.
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

    return caption.strip()

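# Hub helpers for the model selector UI.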
def is_repo_name(s):
    import re
    # fullmatch already anchors the pattern, so no ^...$ is needed.
    return re.fullmatch(r'[^/,\s\"\']+/[^/,\s\"\']+', s) is not None

def is_repo_exists(repo_id):
    from huggingface_hub import HfApi
    try:
        api = HfApi(token=HF_TOKEN)
        return api.repo_exists(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to connect to {repo_id}.")
        print(e)
        return True  # On connection errors, assume the repo exists rather than blocking the load.

def get_text_model():
    return list(llm_models.keys())

def is_gguf_repo(repo_id: str):
    from huggingface_hub import HfApi
    try:
        api = HfApi(token=HF_TOKEN)
        if not is_repo_name(repo_id) or not is_repo_exists(repo_id):
            return False
        files = api.list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        gr.Warning(f"Error: Failed to get {repo_id}'s info.")
        return False
    return any(f.endswith(".gguf") for f in files)

def get_repo_gguf(repo_id: str):
    from huggingface_hub import HfApi
    try:
        api = HfApi(token=HF_TOKEN)
        if not is_repo_name(repo_id) or not is_repo_exists(repo_id):
            return gr.update(value="", choices=[])
        files = api.list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        gr.Warning(f"Error: Failed to get {repo_id}'s info.")
        return gr.update(value="", choices=[])
    files = [f for f in files if f.endswith(".gguf")]
    if len(files) == 0:
        return gr.update(value="", choices=[])
    return gr.update(value=files[0], choices=files)

@spaces.GPU()
def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_file: str | None=None,
                      is_nf4: bool=True, progress=gr.Progress(track_tqdm=True)):
    global use_inference_client, llm_models
    use_inference_client = use_client
    try:
        if not is_repo_name(model_name) or not is_repo_exists(model_name):
            raise gr.Error(f"Repo doesn't exist: {model_name}")
        if not gguf_file and is_gguf_repo(model_name):
            gr.Info("Please select a GGUF file.")
            return gr.update(visible=True)
        if not use_inference_client:
            load_text_model(model_name, gguf_file, is_nf4)
        if model_name not in llm_models:
            llm_models[model_name] = gguf_file if gguf_file else None
        return gr.update(choices=get_text_model())
    except Exception as e:
        raise gr.Error(f"Model load error: {model_name}, {e}")

css = """
body {
    background: linear-gradient(45deg, #1a0033, #4d0099);
    color: #e6ccff;
    font-family: 'Arial', sans-serif;
}
.gradio-container {
    max-width: 1200px !important;
    margin: auto;
}
.gr-button {
    background: linear-gradient(90deg, #8a2be2, #9400d3) !important;
    border: none !important;
    color: white !important;
    font-weight: bold;
    transition: all 0.3s ease;
}
.gr-button:hover {
    background: linear-gradient(90deg, #9400d3, #8a2be2) !important;
    box-shadow: 0 0 15px #9400d3;
}
.gr-form {
    border-radius: 15px;
    padding: 20px;
    background-color: rgba(60, 19, 97, 0.7) !important;
    box-shadow: 0 0 20px rgba(138, 43, 226, 0.4);
    backdrop-filter: blur(10px);
}
.gr-box {
    border-radius: 15px;
    background-color: rgba(75, 0, 130, 0.7) !important;
    box-shadow: 0 0 20px rgba(138, 43, 226, 0.4);
    backdrop-filter: blur(5px);
}
.gr-padded {
    padding: 20px;
}
.gr-form label, .gr-form .label-wrap {
    color: #e6ccff !important;
    font-weight: bold;
}
.gr-input, .gr-dropdown {
    background-color: rgba(47, 1, 71, 0.8) !important;
    border: 2px solid #8a2be2 !important;
    color: #ffffff !important;
    border-radius: 8px;
}
.gr-input::placeholder {
    color: #b19cd9 !important;
}
.gr-checkbox {
    background-color: #4b0082 !important;
    border-color: #8a2be2 !important;
}
.gr-checkbox:checked {
    background-color: #8a2be2 !important;
}
h1, h2, h3 {
    color: #ffd700 !important;
    text-shadow: 0 0 10px rgba(255, 215, 0, 0.5);
}
.gr-block {
    border: none !important;
}
.gr-accordion {
    border: 2px solid #8a2be2;
    border-radius: 10px;
    overflow: hidden;
}
.gr-accordion summary {
    background-color: rgba(75, 0, 130, 0.9);
    color: #ffd700;
    padding: 10px;
    font-weight: bold;
    cursor: pointer;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.HTML(
        "<h1 style='text-align: center; color: #FFD700; text-shadow: 0 0 10px rgba(255, 215, 0, 0.5);'>JoyCaption Alpha One Mod</h1>"
        "<p style='text-align: center; color: #e6ccff;'>Generate captivating captions for your images!</p>"
    )
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                jc_input_image = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"], height=384)
                with gr.Row():
                    jc_caption_type = gr.Dropdown(
                        choices=["descriptive", "training_prompt", "rng-tags"],
                        label="Caption Type",
                        value="descriptive",
                    )
                    jc_caption_tone = gr.Dropdown(
                        choices=["formal", "informal"],
                        label="Caption Tone",
                        value="formal",
                    )
                    jc_caption_length = gr.Dropdown(
                        choices=["any", "very short", "short", "medium-length", "long", "very long"] +
                                [str(i) for i in range(20, 261, 10)],
                        label="Caption Length",
                        value="any",
                    )
                gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.")
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        jc_text_model = gr.Dropdown(label="LLM Model", info="You can enter a Hugging Face model repo_id to use.",
                                                    choices=get_text_model(), value=get_text_model()[0],
                                                    allow_custom_value=True, interactive=True, min_width=320)
                        jc_gguf = gr.Dropdown(label="GGUF Filename", choices=[], value="",
                                              allow_custom_value=True, min_width=320, visible=False)
                    jc_nf4 = gr.Checkbox(label="Use NF4 quantization", value=True)
                    jc_text_model_button = gr.Button("Load Model", variant="secondary")
                    jc_use_inference_client = gr.Checkbox(label="Use Inference Client", value=False, visible=False)
                    with gr.Row():
                        jc_tokens = gr.Slider(minimum=1, maximum=4096, value=300, step=1, label="Max tokens")
                        jc_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature")
                        jc_topp = gr.Slider(minimum=0, maximum=2.0, value=0.9, step=0.01, label="Top-P")
                jc_run_button = gr.Button("Generate Caption", variant="primary")

        with gr.Column(scale=1):
            jc_output_caption = gr.Textbox(label="Generated Caption", show_copy_button=True)

    gr.Markdown(JC_DESC_MD)
    with gr.Row():
        gr.LoginButton()
        gr.DuplicateButton(value="Duplicate Space for private use", variant="secondary")

    jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_caption_type, jc_caption_tone, jc_caption_length, jc_tokens, jc_topp, jc_temperature], outputs=[jc_output_caption])
    jc_text_model_button.click(change_text_model, inputs=[jc_text_model, jc_use_inference_client, jc_gguf, jc_nf4], outputs=[jc_text_model])
    jc_use_inference_client.change(change_text_model, inputs=[jc_text_model, jc_use_inference_client], outputs=[jc_text_model])

if __name__ == "__main__":
    demo.launch(share=True)