Spaces:
Build error
Build error
| """ | |
| caption-space — vision-language descriptions of inspiration images. | |
| What this does: | |
| Takes one or more inspiration images (base64, or a JSON array of base64), | |
| returns a natural-language style description (interior design vocabulary). | |
| Why LLaVA-1.5-7B: | |
| - Open source, no licensing issues | |
| - Strong style-vocabulary captioning (trained on web image-text pairs that | |
| include design/decor blogs) | |
| - Standard transformers pipeline (well-tested, stable API) | |
| - 7B params, ~14 GB at fp16 — fits ZeroGPU A10G comfortably | |
| Use case: | |
| Frontend "Describe this mood board →" button → | |
| POST /describe-inspiration (FastAPI) → | |
| this Space's /describe → | |
| returns prose for user to edit before submitting the redesign job. | |
| Multi-image inputs: | |
| LLaVA-1.5 only takes one image per call. If the user passes a JSON array of | |
| N images, we caption each separately and concatenate with " Also: ". | |
| The user can manually edit the result to merge — fine for POC. Future: | |
| swap to LLaVA-OneVision or Qwen2-VL which handle multi-image natively. | |
| API contract: | |
| api_name="/describe" | |
| Inputs (positional): | |
| 1. images_b64: str — single base64 PNG/JPEG OR JSON-array string | |
| 2. instruction: str — optional override, "" = use default | |
| Output: {"description": str} | |
| """ | |
| import base64 | |
| import io | |
| import json | |
| import os | |
| import traceback | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoProcessor, LlavaForConditionalGeneration | |
# Hub repository of the vision-language checkpoint served by this Space.
MODEL_ID = "llava-hf/llava-1.5-7b-hf"

# Prompt used when the caller does not supply an instruction of their own.
# Split one sentence per fragment; the concatenated text is what the model sees.
DEFAULT_INSTRUCTION = (
    "You are an interior design expert. "
    "Describe the design style of this room in 2-3 sentences. "
    "Mention: style/era (e.g. Japandi, Parisian classical, industrial loft, "
    "mid-century modern), color palette, materials, key furniture pieces, "
    "patterns/textures, and overall vibe. "
    "Be specific with design vocabulary."
)

# Optional Hugging Face access token; empty string when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")
| # --------------------------------------------------------------------------- | |
| # Model load (CPU at startup, moves to GPU inside @spaces.GPU function) | |
| # --------------------------------------------------------------------------- | |
print(f"[caption-space] loading {MODEL_ID}...")

# An empty HF_TOKEN is passed to the loaders as None rather than "".
_hub_token = HF_TOKEN or None

# Processor: builds the chat prompt and the image/text tensors for the model.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=_hub_token)

# Weights are loaded in half precision and stay on the CPU at import time;
# the inference function moves them to the GPU when it runs.
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    token=_hub_token,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

print("[caption-space] model loaded on CPU (will move to GPU on first call).")
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _b64_to_pil(b64: str) -> Image.Image:
    """Decode a base64-encoded image into an RGB PIL image.

    Accepts either a bare base64 payload or a browser-style data URL
    (``data:image/png;base64,<payload>``). The data-URL header is removed
    before decoding: ``base64.b64decode`` in its default (non-validating)
    mode silently discards ``:``, ``;`` and ``,`` but would decode the
    remaining header letters as payload, corrupting the image bytes.

    Args:
        b64: Base64 text of a PNG/JPEG (or any format PIL can open),
            optionally prefixed with a ``data:...;base64,`` header.

    Returns:
        The decoded image converted to ``"RGB"`` mode.

    Raises:
        ValueError: If the payload is not decodable base64
            (``binascii.Error`` is a ``ValueError`` subclass).
        OSError: If PIL cannot open the decoded bytes as an image.
    """
    payload = b64.strip()
    if payload.startswith("data:"):
        # ":" is not in the base64 alphabet, so a bare payload can never
        # start with "data:"; everything up to the first comma is header.
        _, comma, remainder = payload.partition(",")
        if comma:
            payload = remainder

    raw_bytes = base64.b64decode(payload)
    # convert() returns a new, fully loaded image, so the source handle can
    # be closed as soon as the conversion is done.
    with Image.open(io.BytesIO(raw_bytes)) as source:
        return source.convert("RGB")
def _parse_images(images_b64: str) -> list[Image.Image]:
    """Turn the ``images_b64`` request field into a list of PIL images.

    The field is either a single base64 string or a JSON array of base64
    strings. Only the JSON parsing is attempted speculatively; once the
    input is known to be an array, a failure to decode one of its elements
    is reported as-is instead of being swallowed and retried against the
    whole JSON text (which would raise an unrelated, misleading error).

    Args:
        images_b64: A base64 image, or a JSON-array string of base64 images.

    Returns:
        One decoded image per input, in input order. An empty JSON array
        yields an empty list.

    Raises:
        TypeError: If a JSON array element is not a string.
        ValueError / OSError: Propagated from ``_b64_to_pil`` when a payload
            is not valid base64 or not a readable image.
    """
    try:
        parsed = json.loads(images_b64)
    except (TypeError, ValueError):
        # Not JSON at all (json.JSONDecodeError is a ValueError subclass):
        # treat the input as one bare base64 image below.
        parsed = None

    if not isinstance(parsed, list):
        return [_b64_to_pil(images_b64)]

    for index, item in enumerate(parsed):
        if not isinstance(item, str):
            raise TypeError(
                f"images_b64[{index}] must be a base64 string, "
                f"got {type(item).__name__}"
            )
    return [_b64_to_pil(item) for item in parsed]
| # --------------------------------------------------------------------------- | |
| # Inference | |
| # --------------------------------------------------------------------------- | |
@spaces.GPU
def _generate_one(image: Image.Image, instruction: str) -> str:
    """Caption a single image with LLaVA and return the generated text.

    The function is wrapped in ``spaces.GPU`` because it touches CUDA: on
    ZeroGPU hardware a GPU is only attached while a decorated function is
    running, so calling ``model.to("cuda")`` from an undecorated function
    fails at request time.

    Args:
        image: The image to describe.
        instruction: The text prompt placed after the image in the user turn.

    Returns:
        The model's reply with the prompt tokens removed and surrounding
        whitespace stripped.
    """
    model.to("cuda")

    # Single user turn: one image placeholder followed by the instruction.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": instruction},
            ],
        },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    inputs = processor(
        images=image,
        text=prompt,
        return_tensors="pt",
    ).to("cuda", torch.float16)

    # Greedy decoding (do_sample=False) with at most 200 new tokens.
    out = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=False,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = inputs.input_ids.shape[1]
    text = processor.decode(out[0][prompt_len:], skip_special_tokens=True)
    return text.strip()
| # --------------------------------------------------------------------------- | |
| # Main endpoint | |
| # --------------------------------------------------------------------------- | |
def describe(images_b64: str, instruction: str = "") -> dict:
    """Describe one or more inspiration images in natural language.

    Args:
        images_b64: A single base64 image, or a JSON-array string of base64
            images.
        instruction: Optional prompt override. Empty, ``None`` or
            whitespace-only values fall back to ``DEFAULT_INSTRUCTION``.

    Returns:
        ``{"description": str}``. With several images, the per-image
        captions are joined with ``" Also: "`` in input order.

    Raises:
        ValueError: On any failure (bad input, no images, model error). The
            original exception is chained as ``__cause__`` and its traceback
            is printed to the Space logs.
    """
    try:
        # Strip first, then fall back: a whitespace-only override would
        # otherwise reach the model as an empty prompt.
        instr = (instruction or "").strip() or DEFAULT_INSTRUCTION

        images = _parse_images(images_b64)
        if not images:
            # e.g. images_b64 == "[]"; an empty description would hide this.
            raise ValueError("no images provided")
        print(f"[describe] {len(images)} image(s), instr={instr[:80]!r}")

        descriptions = []
        for i, img in enumerate(images):
            desc = _generate_one(img, instr)
            print(f"[describe] image {i}: {desc[:100]!r}")
            descriptions.append(desc)

        # join() on a one-element list returns that element unchanged.
        return {"description": " Also: ".join(descriptions)}
    except Exception as e:
        traceback.print_exc()
        raise ValueError(f"describe failed: {type(e).__name__}: {e}") from e
| # --------------------------------------------------------------------------- | |
| # Gradio interface | |
| # --------------------------------------------------------------------------- | |
# The UI is a thin wrapper around describe(); the button's click event is also
# what registers the "describe" API endpoint.
with gr.Blocks(title="Caption Space (LLaVA-1.5-7B)") as demo:
    gr.Markdown(
        "## LLaVA-1.5-7B — Inspiration Image Captioner\n\n"
        "Describes interior-design style, colors, materials, and key furniture in "
        "natural language. Used by the main pipeline to let users seed a style "
        "prompt from a mood board."
    )

    # Inputs, in the positional order of the API contract.
    images_in = gr.Textbox(
        lines=4,
        label="images_b64 (single base64 OR JSON array of base64 strings)",
    )
    instr_in = gr.Textbox(
        lines=3,
        value=DEFAULT_INSTRUCTION,
        label="instruction (optional override)",
    )

    # Button first, then the JSON panel, so the result renders below the button.
    describe_btn = gr.Button("Describe")
    result_out = gr.JSON(label="result")

    describe_btn.click(
        fn=describe,
        inputs=[images_in, instr_in],
        outputs=result_out,
        api_name="describe",
    )

# show_error=True surfaces the exception message raised by describe() to callers.
demo.launch(show_error=True)