Upload 2 files
Browse files- app.py +2 -2
- joycaption.py +21 -6
app.py
CHANGED
|
@@ -49,14 +49,14 @@ with gr.Blocks(fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
|
|
| 49 |
jc_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature")
|
| 50 |
jc_topp = gr.Slider(minimum=0, maximum=2.0, value=0.9, step=0.01, label="Top-P")
|
| 51 |
jc_run_button = gr.Button("Caption", variant="primary")
|
| 52 |
-
|
| 53 |
with gr.Column():
|
| 54 |
jc_output_caption = gr.Textbox(label="Caption", show_copy_button=True)
|
| 55 |
gr.Markdown(JC_DESC_MD, elem_classes="info")
|
| 56 |
gr.LoginButton()
|
| 57 |
gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
|
| 58 |
|
| 59 |
-
jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_caption_type, jc_caption_tone, jc_caption_length,
|
|
|
|
| 60 |
jc_text_model_button.click(change_text_model, [jc_text_model, jc_use_inference_client, jc_gguf, jc_nf4], [jc_text_model], show_api=False)
|
| 61 |
#jc_text_model.change(get_repo_gguf, [jc_text_model], [jc_gguf], show_api=False)
|
| 62 |
jc_use_inference_client.change(change_text_model, [jc_text_model, jc_use_inference_client], [jc_text_model], show_api=False)
|
|
|
|
| 49 |
jc_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature")
|
| 50 |
jc_topp = gr.Slider(minimum=0, maximum=2.0, value=0.9, step=0.01, label="Top-P")
|
| 51 |
jc_run_button = gr.Button("Caption", variant="primary")
|
|
|
|
| 52 |
with gr.Column():
|
| 53 |
jc_output_caption = gr.Textbox(label="Caption", show_copy_button=True)
|
| 54 |
gr.Markdown(JC_DESC_MD, elem_classes="info")
|
| 55 |
gr.LoginButton()
|
| 56 |
gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
|
| 57 |
|
| 58 |
+
jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_caption_type, jc_caption_tone, jc_caption_length,
|
| 59 |
+
jc_tokens, jc_topp, jc_temperature, jc_text_model], outputs=[jc_output_caption])
|
| 60 |
jc_text_model_button.click(change_text_model, [jc_text_model, jc_use_inference_client, jc_gguf, jc_nf4], [jc_text_model], show_api=False)
|
| 61 |
#jc_text_model.change(get_repo_gguf, [jc_text_model], [jc_gguf], show_api=False)
|
| 62 |
jc_use_inference_client.change(change_text_model, [jc_text_model, jc_use_inference_client], [jc_text_model], show_api=False)
|
joycaption.py
CHANGED
|
@@ -30,9 +30,11 @@ BASE_DIR = Path(__file__).resolve().parent
|
|
| 30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 32 |
use_inference_client = False
|
|
|
|
| 33 |
|
| 34 |
llm_models = {
|
| 35 |
"bunnycore/LLama-3.1-8B-Matrix": None,
|
|
|
|
| 36 |
"Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
|
| 37 |
"unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
|
| 38 |
"DevQuasar/HermesNova-Llama-3.1-8B": None,
|
|
@@ -123,7 +125,6 @@ class ImageAdapter(nn.Module):
|
|
| 123 |
def get_eot_embedding(self):
|
| 124 |
return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
|
| 125 |
|
| 126 |
-
|
| 127 |
# https://huggingface.co/docs/transformers/v4.44.2/gguf
|
| 128 |
# https://github.com/city96/ComfyUI-GGUF/issues/7
|
| 129 |
# https://github.com/THUDM/ChatGLM-6B/issues/18
|
|
@@ -147,6 +148,15 @@ def load_text_model(model_name: str=MODEL_PATH, gguf_file: Union[str, None]=None
|
|
| 147 |
from transformers import BitsAndBytesConfig
|
| 148 |
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
|
| 149 |
bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
print("Loading tokenizer")
|
| 152 |
if gguf_file: tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
|
|
@@ -286,7 +296,8 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str,
|
|
| 286 |
|
| 287 |
@spaces.GPU()
|
| 288 |
@torch.no_grad()
|
| 289 |
-
def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int],
|
|
|
|
| 290 |
global use_inference_client, text_model
|
| 291 |
torch.cuda.empty_cache()
|
| 292 |
gc.collect()
|
|
@@ -312,8 +323,15 @@ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: s
|
|
| 312 |
prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
|
| 313 |
print(f"Prompt: {prompt_str}")
|
| 314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
# Preprocess image
|
| 316 |
-
#image = clip_processor(images=input_image, return_tensors='pt').pixel_values
|
| 317 |
image = input_image.resize((384, 384), Image.LANCZOS)
|
| 318 |
pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
|
| 319 |
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
|
|
@@ -352,9 +370,6 @@ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: s
|
|
| 352 |
attention_mask = torch.ones_like(input_ids)
|
| 353 |
|
| 354 |
text_model.to(device)
|
| 355 |
-
#generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
|
| 356 |
-
#generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
|
| 357 |
-
#generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
|
| 358 |
generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
|
| 359 |
do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)
|
| 360 |
|
|
|
|
| 30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 32 |
use_inference_client = False
|
| 33 |
+
PIXTRAL_PATH = "mistral-community/pixtral-12b"
|
| 34 |
|
| 35 |
llm_models = {
|
| 36 |
"bunnycore/LLama-3.1-8B-Matrix": None,
|
| 37 |
+
#PIXTRAL_PATH: None,
|
| 38 |
"Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
|
| 39 |
"unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
|
| 40 |
"DevQuasar/HermesNova-Llama-3.1-8B": None,
|
|
|
|
| 125 |
def get_eot_embedding(self):
|
| 126 |
return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
|
| 127 |
|
|
|
|
| 128 |
# https://huggingface.co/docs/transformers/v4.44.2/gguf
|
| 129 |
# https://github.com/city96/ComfyUI-GGUF/issues/7
|
| 130 |
# https://github.com/THUDM/ChatGLM-6B/issues/18
|
|
|
|
| 148 |
from transformers import BitsAndBytesConfig
|
| 149 |
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
|
| 150 |
bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
|
| 151 |
+
|
| 152 |
+
if model_name == PIXTRAL_PATH:
|
| 153 |
+
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
| 154 |
+
if is_nf4:
|
| 155 |
+
text_model = LlavaForConditionalGeneration.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
|
| 156 |
+
image_adapter = AutoProcessor.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16)
|
| 157 |
+
else:
|
| 158 |
+
text_model = LlavaForConditionalGeneration.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
|
| 159 |
+
image_adapter = AutoProcessor.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)
|
| 160 |
|
| 161 |
print("Loading tokenizer")
|
| 162 |
if gguf_file: tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
|
|
|
|
| 296 |
|
| 297 |
@spaces.GPU()
|
| 298 |
@torch.no_grad()
|
| 299 |
+
def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int],
|
| 300 |
+
max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> str:
|
| 301 |
global use_inference_client, text_model
|
| 302 |
torch.cuda.empty_cache()
|
| 303 |
gc.collect()
|
|
|
|
| 323 |
prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
|
| 324 |
print(f"Prompt: {prompt_str}")
|
| 325 |
|
| 326 |
+
# Pixtral
|
| 327 |
+
if model_name == PIXTRAL_PATH:
|
| 328 |
+
input_images = [input_image]
|
| 329 |
+
inputs = image_adapter(text=prompt_str, images=input_images, return_tensors="pt").to(device)
|
| 330 |
+
generate_ids = text_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
| 331 |
+
output = image_adapter.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 332 |
+
return output.strip()
|
| 333 |
+
|
| 334 |
# Preprocess image
|
|
|
|
| 335 |
image = input_image.resize((384, 384), Image.LANCZOS)
|
| 336 |
pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
|
| 337 |
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
|
|
|
|
| 370 |
attention_mask = torch.ones_like(input_ids)
|
| 371 |
|
| 372 |
text_model.to(device)
|
|
|
|
|
|
|
|
|
|
| 373 |
generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
|
| 374 |
do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)
|
| 375 |
|