# joycaption / joycaption_app.py
# Last commit by kazuhina: "Fix accelerate dependency and device_map compatibility for Spaces"
# Commit: e6d8d3d
#!/usr/bin/env python3
"""
JoyCaption - Advanced Image Captioning with LLaVA
Uses fancyfeast/llama-joycaption-alpha-two-hf-llava model for high-quality image descriptions
Free, open, and uncensored model for training Diffusion models
"""
import spaces
import gradio as gr
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, LlavaForConditionalGeneration, TextIteratorStreamer
import torch
import torch.amp.autocast_mode
from PIL import Image
import torchvision.transforms.functional as TVF
from threading import Thread
from typing import Generator
import tempfile
import os
from pathlib import Path
# Model configuration
MODEL_PATH = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
# Initialize the JoyCaption model.
# On any failure both `tokenizer` and `model` are set to None; the caption
# handler checks for that and reports "model not loaded" instead of crashing.
print("Loading JoyCaption model...")
try:
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        raise TypeError(f"Expected PreTrainedTokenizer, got {type(tokenizer)}")

    # bitsandbytes 8-bit quantization is a GPU feature: requesting
    # `load_in_8bit=True` together with `device_map="cpu"` is rejected by most
    # transformers/bitsandbytes releases, which left the app stuck in fallback
    # mode. Quantize only when CUDA is visible; otherwise load plain bf16
    # weights on the CPU. `quantization_config` replaces the deprecated bare
    # `load_in_8bit` keyword.
    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "low_cpu_mem_usage": True,
    }
    if torch.cuda.is_available():
        from transformers import BitsAndBytesConfig

        load_kwargs["device_map"] = {"": 0}  # place the whole model on GPU 0
        load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    else:
        load_kwargs["device_map"] = "cpu"

    # LLaVA is a built-in transformers architecture, so `trust_remote_code`
    # (which permits executing code shipped in the model repo) is not needed.
    model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, **load_kwargs)
    if not isinstance(model, LlavaForConditionalGeneration):
        raise TypeError(f"Expected LlavaForConditionalGeneration, got {type(model)}")
    print("JoyCaption model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # Create fallback objects when model loading fails
    tokenizer = None
    model = None
    print("Using fallback mode - model not available")
def trim_off_prompt(input_ids: list[int], eoh_id: int, eot_id: int) -> list[int]:
    """Strip the prompt and trailing end-of-turn part from generated token IDs.

    Everything up to and including the LAST occurrence of ``eoh_id`` is
    dropped; then the result is cut just before the FIRST ``eot_id``.
    When ``eoh_id`` never occurs and ``eot_id`` never occurs either, the
    very same list object that was passed in is returned.
    """
    header_positions = [pos for pos, tok in enumerate(input_ids) if tok == eoh_id]
    body = input_ids[header_positions[-1] + 1:] if header_positions else input_ids

    if eot_id not in body:
        return body
    return body[:body.index(eot_id)]
# Resolve special-token IDs once at import time; both stay None when the
# tokenizer failed to load (fallback mode).
if tokenizer:
    end_of_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
else:
    end_of_header_id = end_of_turn_id = None
def _first_image_path(message):
    """Return the first attached file of a multimodal chat message, or None.

    Entries are expected to be file paths; as a defensive measure, dict
    entries carrying a ``"path"`` key are accepted too. Returns None when
    ``message`` is not a dict or has no attachments.
    """
    if not isinstance(message, dict):
        return None
    files = message.get("files") or []
    if not files:
        return None
    first = files[0]
    if isinstance(first, dict):
        return first.get("path")
    return first


def _resolve_caption_prompt(prompt):
    """Turn the user's text into the instruction sent to JoyCaption.

    - A style keyword (``formal_detailed``, ``creative``, ``simple``,
      ``technical`` or ``custom``; case-insensitive, spaces treated as
      underscores) selects the matching canned instruction.
    - Any other non-empty text is used verbatim as a custom instruction.
    - Empty text falls back to the formal detailed style.
    """
    prompt_templates = {
        "formal_detailed": "Write a long descriptive caption for this image in a formal tone.",
        "creative": "Write a creative and artistic caption for this image, capturing its essence and mood.",
        "simple": "Write a simple, concise caption describing what you see in this image.",
        "technical": "Provide a detailed technical description of this image including composition, lighting, and visual elements.",
        "custom": "Write a descriptive caption for this image.",
    }
    text = prompt.strip()
    key = text.lower().replace(" ", "_")
    if key in prompt_templates:
        return prompt_templates[key]
    return text or prompt_templates["formal_detailed"]


@spaces.GPU()
@torch.no_grad()
def generate_image_caption(message: dict, history, temperature: float = 0.6, top_p: float = 0.9, max_new_tokens: int = 300, log_prompt: bool = False) -> Generator[str, None, None]:
    """Stream a caption for the image attached to a Gradio multimodal message.

    Args:
        message: Multimodal chat message. ``message["text"]`` is either a
            caption-style keyword or a free-form instruction;
            ``message["files"][0]`` is the image to caption.
        history: Chat history passed by ``gr.ChatInterface``; not used.
        temperature: Sampling temperature; 0 switches to greedy decoding.
        top_p: Nucleus-sampling threshold.
        max_new_tokens: Maximum number of tokens to generate.
        log_prompt: If True, print the user's text to stdout.

    Yields:
        The caption generated so far (cumulative text), or an error message.
        On unexpected failures the error is followed by a canned demo caption.
    """
    # Check if model is available
    if model is None or tokenizer is None:
        yield "Error: JoyCaption model not loaded. Please check the model availability and try again."
        return
    torch.cuda.empty_cache()
    # Bound before the try so the error handler can always reference it.
    prompt = ""
    try:
        # Extract prompt from message (`or ""` guards against text=None).
        if isinstance(message, dict):
            prompt = (message.get("text") or "").strip()
        else:
            prompt = str(message).strip()

        image_path = _first_image_path(message)
        if image_path is None:
            yield "ERROR: This model requires exactly one image as input."
            return

        # Log the prompt if requested
        if log_prompt:
            print(f"Prompt: {prompt}")

        # Open via a context manager so the file handle is released. Convert to
        # RGB before resizing: Pillow always resizes palette ("P") images with
        # nearest-neighbour, ignoring the requested LANCZOS filter.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        if image.size != (384, 384):
            image = image.resize((384, 384), Image.LANCZOS)
        pixel_values = TVF.pil_to_tensor(image)

        # Build conversation following JoyCaption's recommended format
        convo = [
            {
                "role": "system",
                "content": "You are a helpful image captioner.",
            },
            {
                "role": "user",
                "content": _resolve_caption_prompt(prompt),
            },
        ]
        convo_string = tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=True,
        )
        if not isinstance(convo_string, str):
            raise TypeError(f"Expected the chat template to render a str, got {type(convo_string)}")
        convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

        # Expand each image placeholder token into `image_seq_length` copies
        # (presumably one slot per visual feature the model splices in).
        image_token = model.config.image_token_index
        input_tokens = []
        for token in convo_tokens:
            if token == image_token:
                input_tokens.extend([image_token] * model.config.image_seq_length)
            else:
                input_tokens.append(token)

        # Put every input on the device the model was actually loaded onto.
        # Guessing "cuda" moved inputs to the GPU while a CPU-loaded model stayed
        # on the CPU, so generation failed with a device mismatch.
        device = model.device
        input_ids = torch.tensor(input_tokens, dtype=torch.long, device=device).unsqueeze(0)
        attention_mask = torch.ones_like(input_ids)

        # Batch the image, scale to [0, 1], then normalize to [-1, 1] and cast to bf16.
        pixel_values = pixel_values.unsqueeze(0).to(device)
        pixel_values = pixel_values / 255.0
        pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
        pixel_values = pixel_values.to(torch.bfloat16)

        # Set up streaming
        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=temperature != 0,  # temperature 0 means greedy decoding
            suppress_tokens=None,
            use_cache=True,
            temperature=temperature,
            top_k=None,
            top_p=top_p,
            streamer=streamer,
        )

        # Generate on a worker thread so tokens can be streamed from here. An
        # exception raised inside model.generate would otherwise never reach
        # this loop; the streamer would only time out after 10 s. Record the
        # error and close the stream so it can be re-raised right away.
        generation_errors = []

        def _run_generation():
            try:
                model.generate(**generate_kwargs)
            except Exception as exc:
                generation_errors.append(exc)
                streamer.end()

        worker = Thread(target=_run_generation)
        worker.start()

        # Stream the output
        outputs = []
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
        worker.join()
        if generation_errors:
            raise generation_errors[0]
    except Exception as e:
        error_msg = f"Error during caption generation: {str(e)}"
        print(error_msg)
        # Still offer the demo caption, but keep the real error visible so a
        # placeholder is never mistaken for a genuine caption.
        yield f"{error_msg}\n\n{generate_demo_caption(prompt)}"
def generate_demo_caption(prompt_type):
    """Build a canned placeholder caption for when real inference is unavailable.

    Args:
        prompt_type: Caption style key ("formal_detailed", "creative",
            "simple", "technical" or "custom"). Any other hashable value
            selects the "formal_detailed" text.

    Returns:
        The placeholder caption followed by a note marking it as a demo response.
    """
    disclaimer = "\n\n[Note: This is a demo response. The full JoyCaption model is optimized for production use and may be temporarily unavailable in this demo environment.]"
    canned_captions = {
        "formal_detailed": "This image appears to contain visual elements including colors, shapes, and composition. The image shows various patterns and visual textures that could be described in detail. The overall scene demonstrates typical characteristics of digital imagery with identifiable visual components.",
        "creative": "A captivating visual composition that captures the essence of artistic expression through color, form, and visual storytelling. The image presents an interesting arrangement of elements that invite creative interpretation and artistic appreciation.",
        "simple": "An image containing visual elements and patterns. The composition shows various colors and shapes arranged in a structured manner.",
        "technical": "Technical analysis: This image demonstrates standard digital image characteristics with RGB color space representation. The resolution and pixel arrangement follow conventional digital imaging protocols with typical compression and formatting.",
        "custom": "Based on the custom prompt provided, this image shows visual elements that could be interpreted according to the specific requirements mentioned.",
    }
    if prompt_type in canned_captions:
        caption = canned_captions[prompt_type]
    else:
        caption = canned_captions["formal_detailed"]
    return caption + disclaimer
# Create Gradio interface
TITLE = "<h1><center>🎨 JoyCaption - Advanced Image Captioning</center></h1>"
# HTML description shown below the chat (rendered through gr.Markdown).
# The emoji are stored as real Unicode characters; a mis-decoded copy
# (UTF-8 bytes read as a Greek code page) would render as garbage.
DESCRIPTION = """
<div>
<p>🧪 This application uses the <strong>JoyCaption</strong> model to generate high-quality, detailed captions for images.</p>
<p><strong>Key Features:</strong></p>
<ul>
<li>🆓 <strong>Free & Open</strong>: No restrictions, open weights, training scripts included</li>
<li>🔓 <strong>Uncensored</strong>: Equal coverage of SFW and NSFW concepts</li>
<li>🌈 <strong>Diversity</strong>: Supports digital art, photoreal, anime, furry, and all styles</li>
<li>🎯 <strong>High Performance</strong>: Near GPT4o-level captioning quality</li>
<li>🔧 <strong>Minimal Filtering</strong>: Trained on diverse images for broad understanding</li>
</ul>
<p><strong>Supported image formats:</strong> PNG, JPG, JPEG, WEBP</p>
<p><strong>Caption Styles:</strong></p>
<ul>
<li><strong>Formal Detailed</strong>: Long descriptive captions in formal tone</li>
<li><strong>Creative</strong>: Artistic and expressive descriptions</li>
<li><strong>Simple</strong>: Concise, straightforward descriptions</li>
<li><strong>Technical</strong>: Detailed technical analysis of composition and elements</li>
<li><strong>Custom</strong>: User-defined prompts for specialized captioning</li>
</ul>
<p><strong>Model:</strong> fancyfeast/llama-joycaption-alpha-two-hf-llava</p>
<p><strong>Architecture:</strong> LLaVA with Llama 3.1 base</p>
</div>
"""
PLACEHOLDER = "Upload an image and describe what kind of caption you'd like..."
# Create chatbot interface.
# The chatbot and multimodal textbox are built up front and handed to
# ChatInterface; the textbox accepts a single image file per message.
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='JoyCaption ChatInterface', type="messages")
textbox = gr.MultimodalTextbox(file_types=["image"], file_count="single")
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    # The additional inputs below are passed positionally to
    # generate_image_caption after (message, history): temperature, top_p,
    # max_new_tokens, log_prompt.
    chat_interface = gr.ChatInterface(
        fn=generate_image_caption,
        chatbot=chatbot,
        type="messages",
        fill_height=True,
        multimodal=True,
        textbox=textbox,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=True, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.6,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.05,
                value=0.9,
                label="Top p",
                render=False,
            ),
            gr.Slider(
                minimum=8,
                maximum=4096,
                step=1,
                value=300,
                label="Max new tokens",
                render=False,
            ),
            gr.Checkbox(
                label="Help improve JoyCaption by logging your text query",
                value=False,
                render=False,
            ),
        ],
    )
    gr.Markdown(DESCRIPTION)
if __name__ == "__main__":
    # Startup banner (emoji stored as real Unicode, not mis-decoded bytes).
    print("🚀 Starting JoyCaption App...")
    print("📱 Interface will be available at: http://localhost:7860")
    print("🎨 Using JoyCaption model by fancyfeast")
    print("🔓 Free, Open, and Uncensored Image Captioning")
    # Launch the interface; 0.0.0.0 binds every network interface.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
        show_error=True,
    )