caption-space / app.py
sciencellama's picture
caption-space: deploy app.py
5fbecfc verified
Raw
History Blame Contribute Delete
6.17 kB
"""
caption-space — vision-language descriptions of inspiration images.
What this does:
Takes one or more inspiration images (base64, or a JSON array of base64),
returns a natural-language style description (interior design vocabulary).
Why LLaVA-1.5-7B:
- Open source, no licensing issues
- Strong style-vocabulary captioning (trained on web image-text pairs that
include design/decor blogs)
- Standard transformers pipeline (well-tested, stable API)
- 7B params, ~14 GB at fp16 — fits ZeroGPU A10G comfortably
Use case:
Frontend "Describe this mood board →" button →
POST /describe-inspiration (FastAPI) →
this Space's /describe →
returns prose for user to edit before submitting the redesign job.
Multi-image inputs:
LLaVA-1.5 only takes one image per call. If the user passes a JSON array of
N images, we caption each separately and concatenate with " Also: ".
The user can manually edit the result to merge — fine for POC. Future:
swap to LLaVA-OneVision or Qwen2-VL which handle multi-image natively.
API contract:
api_name="/describe"
Inputs (positional):
1. images_b64: str — single base64 PNG/JPEG OR JSON-array string
2. instruction: str — optional override, "" = use default
Output: {"description": str}
"""
import base64
import io
import json
import os
import traceback
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
# Hugging Face Hub id of the checkpoint used for both the processor and model.
MODEL_ID = "llava-hf/llava-1.5-7b-hf"

# Prompt sent to the model when the caller does not supply an override.
DEFAULT_INSTRUCTION = (
    "You are an interior design expert. Describe the design style of this room "
    "in 2-3 sentences. Mention: style/era (e.g. Japandi, Parisian classical, "
    "industrial loft, mid-century modern), color palette, materials, key "
    "furniture pieces, patterns/textures, and overall vibe. Be specific with "
    "design vocabulary."
)

# Optional Hub access token; an empty string means "not configured".
HF_TOKEN = os.getenv("HF_TOKEN", "")
# ---------------------------------------------------------------------------
# Model load (CPU at startup, moves to GPU inside @spaces.GPU function)
# ---------------------------------------------------------------------------
# Both objects are created once at import time. The weights are loaded in
# half precision and are only moved to the GPU later, inside the inference
# function.
print(f"[caption-space] loading {MODEL_ID}...")

# Pass None rather than an empty string when no token is configured.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN or None)
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN or None,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

print("[caption-space] model loaded on CPU (will move to GPU on first call).")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _b64_to_pil(b64: str) -> Image.Image:
    """Decode a base64-encoded image into an RGB PIL image.

    Accepts either a bare base64 payload or a browser-style data URL
    (``data:<mime>;base64,<payload>``). The data-URL header is removed before
    decoding: ``base64.b64decode`` silently discards characters outside the
    base64 alphabet, so the header's letters would otherwise be decoded as
    garbage bytes in front of the real image data.

    Args:
        b64: Base64 text of an encoded image (e.g. PNG or JPEG), optionally
            prefixed with a data-URL header.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        binascii.Error: If the payload is not valid base64 (e.g. bad padding).
        Exception: Whatever Pillow raises when the decoded bytes are not a
            readable image.
    """
    payload = b64
    if isinstance(payload, str):
        payload = payload.strip()
        # A bare base64 payload can never contain ":" or ",", so this check
        # cannot misfire on plain input.
        if payload[:5].lower() == "data:":
            payload = payload.partition(",")[2]
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw)).convert("RGB")
def _parse_images(images_b64: str) -> list[Image.Image]:
    """Turn the request payload into a list of PIL images.

    The payload may be:
      * a JSON array of base64 strings  -> one image per element,
      * a JSON-quoted base64 string     -> one image,
      * a bare base64 string            -> one image.

    Only the JSON parse is guarded. Decoding errors for individual images
    propagate unchanged, so a corrupt array element is reported as such
    instead of triggering a second, misleading attempt to decode the whole
    JSON text as a single image.

    Args:
        images_b64: Raw payload as described above.

    Returns:
        The decoded images, in input order. An empty JSON array yields an
        empty list.

    Raises:
        Exception: Any error raised by ``_b64_to_pil`` for an undecodable
            image.
    """
    try:
        parsed = json.loads(images_b64)
    except (json.JSONDecodeError, TypeError, ValueError):
        # Not JSON at all: treat the payload as one bare base64 string.
        parsed = None

    if isinstance(parsed, list):
        return [_b64_to_pil(item) for item in parsed]
    if isinstance(parsed, str):
        # JSON-quoted single string: decode the unquoted content.
        return [_b64_to_pil(parsed)]
    return [_b64_to_pil(images_b64)]
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
@spaces.GPU(duration=60)
def _generate_one(image: Image.Image, instruction: str) -> str:
    """Caption one image with the module-level LLaVA model.

    Builds a single-turn chat prompt (one image placeholder plus the
    instruction text), runs greedy generation capped at 200 new tokens on the
    GPU, and returns only the newly generated text.

    Args:
        image: The image to describe.
        instruction: The text prompt placed after the image in the user turn.

    Returns:
        The generated description with surrounding whitespace removed.
    """
    # The model is loaded on CPU at import time; move it over before use.
    model.to("cuda")

    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": instruction},
        ],
    }
    chat_prompt = processor.apply_chat_template(
        [user_turn],
        add_generation_prompt=True,
    )

    batch = processor(images=image, text=chat_prompt, return_tensors="pt")
    batch = batch.to("cuda", torch.float16)

    generated = model.generate(**batch, max_new_tokens=200, do_sample=False)

    # The output sequence starts with the prompt tokens; slice them off so
    # only the model's completion is decoded.
    n_prompt_tokens = batch.input_ids.shape[1]
    completion_ids = generated[0][n_prompt_tokens:]
    decoded = processor.decode(completion_ids, skip_special_tokens=True)
    return decoded.strip()
# ---------------------------------------------------------------------------
# Main endpoint
# ---------------------------------------------------------------------------
def describe(images_b64: str, instruction: str = "") -> dict:
    """Describe one or more inspiration images in natural language.

    Each image is captioned independently; multiple captions are joined with
    " Also: ".

    Args:
        images_b64: A single base64 image string or a JSON array of them.
        instruction: Optional prompt override. Empty, ``None`` or
            whitespace-only values fall back to ``DEFAULT_INSTRUCTION``.

    Returns:
        ``{"description": <text>}``. The text is empty if the payload was an
        empty JSON array.

    Raises:
        ValueError: Wraps any failure during parsing or generation. The
            original exception is chained as ``__cause__`` and its traceback
            is printed.
    """
    try:
        # Strip before falling back: a whitespace-only override is truthy, so
        # checking truthiness first would send an empty prompt to the model.
        instr = (instruction or "").strip() or DEFAULT_INSTRUCTION
        images = _parse_images(images_b64)
        print(f"[describe] {len(images)} image(s), instr={instr[:80]!r}")

        descriptions = []
        for i, img in enumerate(images):
            desc = _generate_one(img, instr)
            print(f"[describe] image {i}: {desc[:100]!r}")
            descriptions.append(desc)

        # join() returns the lone element unchanged for a single image, so no
        # special case is needed.
        return {"description": " Also: ".join(descriptions)}
    except Exception as e:
        # Top-level API boundary: log the full traceback, then surface one
        # uniform error type to the caller.
        traceback.print_exc()
        raise ValueError(f"describe failed: {type(e).__name__}: {e}") from e
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="Caption Space (LLaVA-1.5-7B)") as demo:
    # Short header explaining what the Space is for.
    gr.Markdown(
        "## LLaVA-1.5-7B — Inspiration Image Captioner\n\n"
        "Describes interior-design style, colors, materials, and key furniture in "
        "natural language. Used by the main pipeline to let users seed a style "
        "prompt from a mood board."
    )

    # Inputs, in the positional order the "describe" endpoint receives them.
    images_in = gr.Textbox(
        lines=4,
        label="images_b64 (single base64 OR JSON array of base64 strings)",
    )
    instr_in = gr.Textbox(
        lines=3,
        value=DEFAULT_INSTRUCTION,
        label="instruction (optional override)",
    )

    # Button first, then the result panel, so they are created in that order.
    describe_btn = gr.Button("Describe")
    result_out = gr.JSON(label="result")

    describe_btn.click(
        fn=describe,
        inputs=[images_in, instr_in],
        outputs=result_out,
        api_name="describe",
    )

# show_error=True asks Gradio to display exceptions raised by the handler.
demo.launch(show_error=True)