# Hug0endob's picture
# Create app.py
# 71b45b9 verified
# raw
# history blame
# 2.14 kB
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
import gradio as gr
from PIL import Image
import requests
# -------------------------------------------------
# Model identifier (replace if you fork or use a different checkpoint)
# -------------------------------------------------
MODEL_NAME = "fpgaminer/joycaption-llama3.1-8b"
# NOTE(review): the checkpoint name suggests ~8B parameters; at 2 bytes per
# parameter in bfloat16 that is roughly 16 GB of weights alone, so a small CPU
# host may not have enough RAM — confirm before deploying.


# -------------------------------------------------
# Load processor and model (CPU only)
# -------------------------------------------------
def _load_processor_and_model(checkpoint: str):
    """Load the processor and the LLaVA model for *checkpoint* onto the CPU.

    ``device_map="cpu"`` keeps every weight on the CPU. The model is loaded in
    bfloat16, which halves memory compared with float32. NOTE(review): confirm
    that bfloat16 is the checkpoint's native dtype. The model is switched to
    eval mode before it is returned.

    Returns:
        A ``(processor, model)`` pair.
    """
    loaded_processor = AutoProcessor.from_pretrained(checkpoint)
    loaded_model = LlavaForConditionalGeneration.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )
    loaded_model.eval()
    return loaded_processor, loaded_model


processor, llava_model = _load_processor_and_model(MODEL_NAME)
# -------------------------------------------------
# Inference function used by Gradio
# -------------------------------------------------
_DEFAULT_PROMPT = "Describe the image."


def generate_caption(
    image: Image.Image,
    prompt: str = _DEFAULT_PROMPT,
    *,
    max_new_tokens: int = 64,
) -> str:
    """Generate a caption for *image* with the module-level LLaVA model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input. Gradio
            passes ``None`` when nothing has been uploaded.
        prompt: Instruction text for the model. An empty or whitespace-only
            prompt falls back to the default, since the textbox is optional.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded caption, without the echoed prompt tokens and with
        surrounding whitespace stripped.

    Raises:
        gr.Error: If no image was supplied. Gradio shows the message to the
            user instead of a generic error.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # The prompt box is labelled "optional"; a cleared box must not yield an
    # empty instruction.
    if not prompt or not prompt.strip():
        prompt = _DEFAULT_PROMPT

    # LLaVA processors expand the image placeholder token into the slots that
    # receive the vision features. Without it, the model cannot match image
    # features to text positions, so insert it when the user did not.
    image_token = getattr(processor, "image_token", "<image>")
    if image_token not in prompt:
        prompt = f"{image_token}\n{prompt}"
    # NOTE(review): instruction-tuned checkpoints usually expect their chat
    # template (processor.apply_chat_template); confirm the prompt format this
    # checkpoint was trained with.

    # RGBA, palette and grayscale uploads would give the image processor the
    # wrong channel count.
    rgb_image = image.convert("RGB")
    raw_inputs = processor(images=rgb_image, text=prompt, return_tensors="pt")

    # Move all tensors to the model's device. Also cast floating-point ones
    # (pixel_values) to the model's dtype: the model is loaded in bfloat16, and
    # float32 pixels would mismatch its weights. Integer tensors
    # (input_ids, attention_mask) must keep their dtype.
    device, dtype = llava_model.device, llava_model.dtype
    model_inputs = {
        name: (
            tensor.to(device=device, dtype=dtype)
            if torch.is_floating_point(tensor)
            else tensor.to(device)
        )
        for name, tensor in raw_inputs.items()
    }

    with torch.no_grad():
        output_ids = llava_model.generate(
            **model_inputs, max_new_tokens=max_new_tokens
        )

    # For decoder-only models, generate() returns prompt + continuation. Drop
    # the prompt so that only the new caption is decoded.
    prompt_length = model_inputs["input_ids"].shape[1]
    caption = processor.decode(
        output_ids[0][prompt_length:], skip_special_tokens=True
    )
    return caption.strip()
# -------------------------------------------------
# Gradio UI
# -------------------------------------------------
# Components are created in the same order as they appear on the page: the
# image upload, the optional prompt, then the caption output.
_image_input = gr.Image(type="pil", label="Upload an image")
_prompt_input = gr.Textbox(label="Prompt (optional)", value="Describe the image.")
_caption_output = gr.Textbox(label="Generated caption")

# Each submission calls generate_caption(image, prompt) and shows the
# returned string in the caption box.
# NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
# `flagging_mode`; it is kept here for compatibility with older versions.
iface = gr.Interface(
    fn=generate_caption,
    inputs=[_image_input, _prompt_input],
    outputs=_caption_output,
    title="JoyCaption (CPU‑only) Demo",
    description="Upload an image and let the JoyCaption model generate a caption. Runs entirely on the free CPU tier.",
    allow_flagging="never",
)
def main() -> None:
    """Start the Gradio server that hosts the caption demo."""
    iface.launch()


# Launch only when the file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()