ak / app.py
AkashKumarave's picture
Update app.py
68e3db1 verified
raw
history blame
10.1 kB
import io
import logging
import os
import tempfile

import gradio as gr
import vtracer
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
# Module-wide logging: INFO and above, one logger named after this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application that carries the REST endpoints
app = FastAPI()

# CORS policy for browser callers (Figma plugin and local testing).
# NOTE(review): the "*" entry presumably makes the explicit Figma origin
# redundant, and a wildcard combined with allow_credentials=True is very
# permissive -- confirm this is intended before tightening it.
_cors_settings = {
    "allow_origins": ["https://www.figma.com", "*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# VTracer conversion function
def convert_to_vector(
    image,
    colormode="color",
    hierarchical="stacked",
    mode="spline",
    filter_speckle=4,
    color_precision=6,
    layer_difference=16,
    corner_threshold=60,
    length_threshold=4.0,
    max_iterations=10,
    splice_threshold=45,
    path_precision=3
):
    """Trace a raster image with VTracer and return the SVG markup.

    Args:
        image: A PIL image (anything exposing ``mode``, ``convert`` and
            ``save`` the way ``PIL.Image.Image`` does).
        colormode, hierarchical, mode: String options forwarded unchanged
            to ``vtracer.convert_image_to_svg_py``.
        filter_speckle, color_precision, layer_difference, corner_threshold,
        max_iterations, splice_threshold, path_precision: Numeric options,
            coerced to ``int`` before being forwarded.
        length_threshold: Numeric option, coerced to ``float``.

    Returns:
        The contents of the SVG file produced by VTracer, as a string.

    Raises:
        HTTPException: With status 500 when saving, tracing or reading the
            result fails for any reason; the original error is chained.
    """
    try:
        # A private scratch directory per call: fixed file names in the
        # working directory would let concurrent requests overwrite (or
        # delete) each other's input/output files. The directory and its
        # contents are removed when the ``with`` block exits.
        with tempfile.TemporaryDirectory() as work_dir:
            input_path = os.path.join(work_dir, "input.png")
            output_path = os.path.join(work_dir, "output.svg")

            # Normalise to RGB/RGBA and save as PNG. Writing a JPEG fails
            # for images with an alpha channel or a palette, and PNG also
            # avoids re-compressing the pixels before tracing.
            if image.mode not in ("RGB", "RGBA"):
                has_alpha = (
                    "A" in image.getbands() or "transparency" in image.info
                )
                image = image.convert("RGBA" if has_alpha else "RGB")
            image.save(input_path, format="PNG")
            logger.info(f"Saved image to {input_path}")

            # Convert the image to SVG using VTracer
            vtracer.convert_image_to_svg_py(
                input_path,
                output_path,
                colormode=colormode,
                hierarchical=hierarchical,
                mode=mode,
                filter_speckle=int(filter_speckle),
                color_precision=int(color_precision),
                layer_difference=int(layer_difference),
                corner_threshold=int(corner_threshold),
                length_threshold=float(length_threshold),
                max_iterations=int(max_iterations),
                splice_threshold=int(splice_threshold),
                path_precision=int(path_precision)
            )
            logger.info(f"Converted image to SVG at {output_path}")

            # Read the SVG output before the scratch directory disappears
            with open(output_path, "r", encoding="utf-8") as f:
                return f.read()
    except Exception as e:
        logger.error(f"Error in convert_to_vector: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Conversion failed: {str(e)}"
        ) from e
# FastAPI endpoint for vector conversion
@app.post("/convert")
async def convert_image(
    file: UploadFile = File(...),
    colormode: str = Form("color"),
    hierarchical: str = Form("stacked"),
    mode: str = Form("spline"),
    filter_speckle: int = Form(4),
    color_precision: int = Form(6),
    layer_difference: int = Form(16),
    corner_threshold: int = Form(60),
    length_threshold: float = Form(4.0),
    max_iterations: int = Form(10),
    splice_threshold: int = Form(45),
    path_precision: int = Form(3)
):
    """Convert an uploaded raster image to SVG.

    The upload is read fully into memory, opened with Pillow and handed to
    ``convert_to_vector`` together with the tracing options taken from the
    form fields.

    Returns:
        A JSON response ``{"svg": <markup>}`` on success, or
        ``{"error": <message>}`` with status 500 when anything raises.
    """
    try:
        logger.info("Received request to /convert")
        # Read the uploaded image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data))
        # Convert to SVG; every option is passed by keyword so the call
        # cannot drift out of step with the parameter order
        svg_content = convert_to_vector(
            image,
            colormode=colormode,
            hierarchical=hierarchical,
            mode=mode,
            filter_speckle=filter_speckle,
            color_precision=color_precision,
            layer_difference=layer_difference,
            corner_threshold=corner_threshold,
            length_threshold=length_threshold,
            max_iterations=max_iterations,
            splice_threshold=splice_threshold,
            path_precision=path_precision
        )
        return JSONResponse(content={"svg": svg_content})
    except Exception as e:
        logger.error(f"Error in convert_image: {str(e)}")
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Liveness probe served at the application root
@app.get("/")
async def health_check():
    """Log the probe and answer with a fixed "healthy" status payload."""
    logger.info("Health check requested")
    status_payload = {"status": "healthy"}
    return status_payload
# Gradio interface helpers
def handle_color_mode(value):
    """Return the selected option unchanged (identity handler)."""
    selected = value
    return selected
# Offer only the bundled sample images that are actually present on disk;
# when none exist the list is empty
examples_dir = "examples"
_example_candidates = [
    os.path.join(examples_dir, name)
    for name in ("11.jpg", "02.jpg", "03.jpg")
]
examples = list(filter(os.path.exists, _example_candidates))
# Stylesheet passed to gr.Blocks(css=...): rules for the element with id
# "col-container" and for elements carrying the "generate-btn" class.
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
.generate-btn {
background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
border: none !important;
color: white !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
"""
# Define the Gradio interface
with gr.Blocks(css=css) as gradio_app:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
<div style="text-align: center;">
<h2>Image to Vector Converter ⚡</h2>
<p>Converts raster images (JPG, PNG, WEBP) to vector graphics (SVG).</p>
</div>
""")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image")
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Accordion("Clustering", open=False):
                        colormode = gr.Radio([("COLOR", "color"), ("B/W", "binary")], value="color", label="Color Mode", show_label=False)
                        filter_speckle = gr.Slider(0, 128, value=4, step=1, label="Filter Speckle", info="Cleaner")
                        color_precision = gr.Slider(1, 8, value=6, step=1, label="Color Precision", info="More accurate")
                        layer_difference = gr.Slider(0, 128, value=16, step=1, label="Gradient Step", info="Less layers")
                        hierarchical = gr.Radio([("STACKED", "stacked"), ("CUTOUT", "cutout")], value="stacked", label="Hierarchical Mode", show_label=False)
                    with gr.Accordion("Curve Fitting", open=False):
                        mode = gr.Radio([("SPLINE", "spline"), ("POLYGON", "polygon"), ("PIXEL", "none")], value="spline", label="Mode", show_label=False)
                        corner_threshold = gr.Slider(0, 180, value=60, step=1, label="Corner Threshold", info="Smoother")
                        length_threshold = gr.Slider(3.5, 10, value=4.0, step=0.1, label="Segment Length", info="More coarse")
                        splice_threshold = gr.Slider(0, 180, value=45, step=1, label="Splice Threshold", info="Less accurate")
                        max_iterations = gr.Slider(1, 20, value=10, step=1, label="Max Iterations", visible=False)
                        path_precision = gr.Slider(1, 10, value=3, step=1, label="Path Precision", visible=False)
                # Hidden textbox that receives the most recently changed
                # radio value (see the .change handlers below)
                output_text = gr.Textbox(label="Selected Mode", visible=False)
                with gr.Row():
                    clear_button = gr.Button("Clear")
                    convert_button = gr.Button("✨ Convert to SVG", variant="primary", elem_classes=["generate-btn"])
            with gr.Column():
                html = gr.HTML(label="SVG Output")
                svg_output = gr.File(label="Download SVG")

        def _convert_for_ui(
            image,
            colormode="color",
            hierarchical="stacked",
            mode="spline",
            filter_speckle=4,
            color_precision=6,
            layer_difference=16,
            corner_threshold=60,
            length_threshold=4.0,
            max_iterations=10,
            splice_threshold=45,
            path_precision=3,
        ):
            """Adapt ``convert_to_vector`` to the two UI outputs.

            ``convert_to_vector`` returns a single SVG string, but the UI
            declares two outputs (an HTML preview and a downloadable file),
            so this wrapper produces one value for each.

            Returns:
                ``(preview_markup, svg_path)``: the SVG wrapped in a scaling
                ``<svg viewBox=...>`` element, and the path of a ``.svg``
                file holding the raw markup.

            Raises:
                gr.Error: When no image was supplied or the conversion
                    failed, so the user gets a readable message.
            """
            if image is None:
                raise gr.Error("Please upload an image first.")
            try:
                svg_content = convert_to_vector(
                    image, colormode, hierarchical, mode, filter_speckle,
                    color_precision, layer_difference, corner_threshold,
                    length_threshold, max_iterations, splice_threshold,
                    path_precision,
                )
            except HTTPException as e:
                raise gr.Error(str(e.detail)) from e
            # delete=False: the file must outlive this call so the File
            # component can serve it.
            # NOTE(review): nothing here removes it afterwards -- confirm
            # whether the deployment needs explicit cleanup.
            with tempfile.NamedTemporaryFile(
                "w", suffix=".svg", delete=False, encoding="utf-8"
            ) as svg_file:
                svg_file.write(svg_content)
            preview_markup = (
                f'<svg viewBox="0 0 {image.width} {image.height}">'
                f"{svg_content}</svg>"
            )
            return preview_markup, svg_file.name

        if examples:
            gr.Examples(
                examples=examples,
                fn=_convert_for_ui,
                inputs=[image_input],
                outputs=[html, svg_output],
                cache_examples=False,
                run_on_click=True
            )

    # Mirror the latest radio selection into the hidden textbox
    colormode.change(handle_color_mode, inputs=colormode, outputs=output_text)
    hierarchical.change(handle_color_mode, inputs=hierarchical, outputs=output_text)
    mode.change(handle_color_mode, inputs=mode, outputs=output_text)

    # Not referenced elsewhere in this file; kept in case other code reads it
    default_values = {
        "color_precision": 6,
        "layer_difference": 16
    }

    def clear_inputs():
        """Return the initial value for every input, in the same order as
        the ``outputs`` list of ``clear_button.click`` below."""
        return (
            gr.Image(value=None), gr.Radio(value="color"), gr.Radio(value="stacked"),
            gr.Radio(value="spline"), gr.Slider(value=4), gr.Slider(value=6),
            gr.Slider(value=16), gr.Slider(value=60), gr.Slider(value=4.0),
            gr.Slider(value=10), gr.Slider(value=45), gr.Slider(value=3)
        )

    def update_interactivity_and_visibility(colormode, color_precision_value, layer_difference_value):
        """Enable the two colour sliders and show the hierarchical radio
        only in "color" mode; the two slider values are accepted but unused."""
        is_color_mode = colormode == "color"
        return (
            gr.update(interactive=is_color_mode),
            gr.update(interactive=is_color_mode),
            gr.update(visible=is_color_mode)
        )

    colormode.change(
        update_interactivity_and_visibility,
        inputs=[colormode, color_precision, layer_difference],
        outputs=[color_precision, layer_difference, hierarchical]
    )

    def update_interactivity_and_visibility_for_mode(mode):
        """Enable the three curve-fitting sliders only in "spline" mode."""
        is_spline_mode = mode == "spline"
        return (
            gr.update(interactive=is_spline_mode),
            gr.update(interactive=is_spline_mode),
            gr.update(interactive=is_spline_mode)
        )

    mode.change(
        update_interactivity_and_visibility_for_mode,
        inputs=[mode],
        outputs=[corner_threshold, length_threshold, splice_threshold]
    )
    clear_button.click(
        clear_inputs,
        outputs=[
            image_input, colormode, hierarchical, mode, filter_speckle,
            color_precision, layer_difference, corner_threshold, length_threshold,
            max_iterations, splice_threshold, path_precision
        ]
    )
    convert_button.click(
        _convert_for_ui,
        inputs=[
            image_input, colormode, hierarchical, mode, filter_speckle,
            color_precision, layer_difference, corner_threshold, length_threshold,
            max_iterations, splice_threshold, path_precision
        ],
        outputs=[html, svg_output],
        # Keep the endpoint name the original wiring exposed
        api_name="convert_to_vector"
    )
# Serve the Gradio UI from the FastAPI app under /gradio. A failure here is
# logged rather than raised, so the REST endpoints stay importable.
try:
    from gradio import mount_gradio_app
    app = mount_gradio_app(app, gradio_app, path="/gradio")
except Exception as mount_error:
    logger.error(f"Failed to mount Gradio app: {str(mount_error)}")
else:
    logger.info("Gradio app mounted successfully at /gradio")