# Source: Hugging Face file view of app.py
# Author: bharatcoder
# Last commit: "Update app.py" (0a09255, verified)
# File size: 3.49 kB
import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import base64
from io import BytesIO
import os
# -----------------------------
# Load model and processor once
# -----------------------------
# Processor and model are loaded from the same checkpoint, so name it once.
_CHECKPOINT = "ds4sd/SmolDocling-256M-preview"

processor = AutoProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageTextToText.from_pretrained(_CHECKPOINT)
# -----------------------------
# Image conversion helper
# -----------------------------
def convert_to_pil(image_input):
    """
    Convert a supported image payload into a PIL image.

    Accepted inputs:
      - a data URI such as "data:image/png;base64,...."
      - a plain base64 string (no prefix)
      - a dict carrying the payload under "data",
        e.g. {"type": "image", "data": "..."}
      - a path to an existing local file

    Args:
        image_input: One of the payload shapes listed above.

    Returns:
        The object returned by ``Image.open`` for the decoded bytes or path.

    Raises:
        ValueError: If the input matches none of the supported shapes, or
            if it is a data URI that has no "," before its payload.
    """
    # Plain base64 is only attempted on strings longer than this, so short
    # strings (typically file paths) are not mistaken for image data.
    min_base64_length = 100

    # Dict wrapper: unwrap it and treat the payload like any other input.
    if isinstance(image_input, dict) and "data" in image_input:
        image_input = image_input["data"]

    # Every remaining supported shape is a string.
    if not isinstance(image_input, str):
        raise ValueError("Could not convert image input to PIL.Image")

    # Data URI: everything after the first comma is the base64 payload.
    if image_input.startswith("data:image"):
        _header, separator, payload = image_input.partition(",")
        if not separator:
            # Previously surfaced as an IndexError from split(...)[1].
            raise ValueError(
                "Malformed data URI: missing ',' before the base64 payload"
            )
        return Image.open(BytesIO(base64.b64decode(payload)))

    # Plain base64. The base64 alphabet contains no comma, so a comma must
    # not be required here; any sufficiently long string is attempted.
    if len(image_input) > min_base64_length:
        try:
            return Image.open(BytesIO(base64.b64decode(image_input)))
        except (ValueError, OSError):
            # Not valid base64 (binascii.Error is a ValueError) or not a
            # readable image: fall through and try it as a file path.
            pass

    # Local file path.
    # NOTE(review): this opens whatever existing path the caller supplies;
    # confirm that is acceptable if the tool is reachable by untrusted clients.
    if os.path.exists(image_input):
        return Image.open(image_input)

    raise ValueError("Could not convert image input to PIL.Image")
# -----------------------------
# Core function
# -----------------------------
def smoldocling_readimage(image: Image.Image, prompt_text: str) -> str:
    """
    Generate text for one image and one prompt with the loaded model.

    Args:
        image: The image passed to the processor alongside the prompt.
        prompt_text: The text part of the single user message.

    Returns:
        The decoded generation with every "<end_of_utterance>" marker
        removed and surrounding whitespace stripped.
    """
    # One user turn: an image placeholder followed by the text prompt.
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt_text},
        ],
    }
    chat_prompt = processor.apply_chat_template(
        [user_turn], add_generation_prompt=True
    )
    model_inputs = processor(text=chat_prompt, images=[image], return_tensors="pt")
    n_prompt_tokens = model_inputs.input_ids.shape[1]

    token_ids = model.generate(**model_inputs, max_new_tokens=1024)

    # Skip the first n_prompt_tokens positions so that only the tokens after
    # the prompt are decoded.
    new_token_ids = token_ids[:, n_prompt_tokens:]
    decoded = processor.batch_decode(new_token_ids, skip_special_tokens=False)
    first_text = decoded[0]

    # Special tokens are kept by the decode call above; only the
    # end-of-utterance marker is removed here.
    return first_text.replace("<end_of_utterance>", "").strip()
# -----------------------------
# Wrapper for MCP schema compatibility
# -----------------------------
def smoldocling_entry(image, prompt_text: str) -> str:
    """
    Entry point for MCP tool.
    Accepts any of:
    - base64 string
    - dict {"type": "image", "data": "data:image/png;base64,..."}
    - file path
    """
    # NOTE(review): the docstring text is kept unchanged on purpose; Gradio
    # presumably surfaces it as the tool description, so confirm before rewording.
    # Normalise the incoming payload to an image, then run the model on it.
    return smoldocling_readimage(convert_to_pil(image), prompt_text)
# -----------------------------
# Gradio MCP App (Headless)
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        """
        ### 📄 SmolDocling MCP Tool
        This is a **headless MCP tool** for document image conversion.
        It supports input as:
        - Base64-encoded images
        - Perplexity/Claude `{"type": "image", "data": "..."}` objects
        - Local file paths (for testing)
        """
    )
    # Expose MCP tool
    gr.api(smoldocling_entry)

# Launch MCP server mode.
# When run as a script, launch() blocks the main thread until the server is
# shut down, so a print placed after it would only appear at shutdown. Start
# the server without blocking, announce the URL, then block explicitly so the
# process stays alive.
_, url, _ = demo.launch(mcp_server=True, prevent_thread_lock=True)
print(f"✅ MCP Server running at: {url}")
demo.block_thread()