AkashKumarave commited on
Commit
68e3db1
·
verified ·
1 Parent(s): 747ac08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -33
app.py CHANGED
@@ -18,13 +18,13 @@ app = FastAPI()
18
  # Configure CORS to allow requests from Figma plugin
19
  app.add_middleware(
20
  CORSMiddleware,
21
- allow_origins=["https://www.figma.com"],
22
  allow_credentials=True,
23
  allow_methods=["*"],
24
  allow_headers=["*"],
25
  )
26
 
27
- # Existing vtracer conversion function
28
  def convert_to_vector(
29
  image,
30
  colormode="color",
@@ -45,6 +45,7 @@ def convert_to_vector(
45
  try:
46
  # Save the input image to a temporary file
47
  image.save(input_path)
 
48
 
49
  # Convert the image to SVG using VTracer
50
  vtracer.convert_image_to_svg_py(
@@ -62,21 +63,23 @@ def convert_to_vector(
62
  splice_threshold=int(splice_threshold),
63
  path_precision=int(path_precision)
64
  )
 
65
 
66
  # Read the SVG output
67
  with open(output_path, "r") as f:
68
  svg_content = f.read()
69
 
70
  return svg_content
71
- # except Exception as e:
72
- # logger.error(f"Error in convert_to_vector: {str(e)}")
73
- # raise HTTPException(status_code=500, detail=f"Conversion failed: {str(e)}")
74
  finally:
75
  # Clean up temporary files
76
  for path in [input_path, output_path]:
77
  if os.path.exists(path):
78
  try:
79
  os.remove(path)
 
80
  except Exception as e:
81
  logger.warning(f"Failed to remove {path}: {str(e)}")
82
 
@@ -97,6 +100,7 @@ async def convert_image(
97
  path_precision: int = Form(3)
98
  ):
99
  try:
 
100
  # Read the uploaded image
101
  image_data = await file.read()
102
  image = Image.open(io.BytesIO(image_data))
@@ -112,9 +116,9 @@ async def convert_image(
112
  layer_difference,
113
  corner_threshold,
114
  length_threshold,
115
- max_iterations,
116
- splice_threshold,
117
- path_precision
118
  )
119
 
120
  return JSONResponse(content={"svg": svg_content})
@@ -122,14 +126,21 @@ async def convert_image(
122
  logger.error(f"Error in convert_image: {str(e)}")
123
  return JSONResponse(content={"error": str(e)}, status_code=500)
124
 
 
 
 
 
 
 
125
  # Gradio interface
126
  def handle_color_mode(value):
127
  return value
128
 
 
 
129
  examples = [
130
- "examples/11.jpg",
131
- "examples/02.jpg",
132
- "examples/03.jpg",
133
  ]
134
 
135
  css = """
@@ -162,39 +173,41 @@ with gr.Blocks(css=css) as gradio_app:
162
  image_input = gr.Image(type="pil", label="Upload Image")
163
  with gr.Accordion("Advanced Settings", open=False):
164
  with gr.Accordion("Clustering", open=False):
165
- colormode = gr.Radio([("COLOR","color"),("B/W", "binary")], value="color", label="Color Mode", show_label=False)
166
  filter_speckle = gr.Slider(0, 128, value=4, step=1, label="Filter Speckle", info="Cleaner")
167
  color_precision = gr.Slider(1, 8, value=6, step=1, label="Color Precision", info="More accurate")
168
  layer_difference = gr.Slider(0, 128, value=16, step=1, label="Gradient Step", info="Less layers")
169
- hierarchical = gr.Radio([("STACKED","stacked"), ("CUTOUT","cutout")], value="stacked", label="Hierarchical Mode",show_label=False)
170
  with gr.Accordion("Curve Fitting", open=False):
171
- mode = gr.Radio([("SPLINE","spline"),("POLYGON", "polygon"), ("PIXEL","none")], value="spline", label="Mode", show_label=False)
172
  corner_threshold = gr.Slider(0, 180, value=60, step=1, label="Corner Threshold", info="Smoother")
173
- length_threshold = gr.Slider(3.5, 10, value=4.0, step=0.1, label="Segment Length", info ="More coarse")
174
  splice_threshold = gr.Slider(0, 180, value=45, step=1, label="Splice Threshold", info="Less accurate")
175
  max_iterations = gr.Slider(1, 20, value=10, step=1, label="Max Iterations", visible=False)
176
  path_precision = gr.Slider(1, 10, value=3, step=1, label="Path Precision", visible=False)
177
  output_text = gr.Textbox(label="Selected Mode", visible=False)
178
  with gr.Row():
179
  clear_button = gr.Button("Clear")
180
- convert_button = gr.Button("✨ Convert to SVG", variant='primary', elem_classes=["generate-btn"])
181
 
182
  with gr.Column():
183
  html = gr.HTML(label="SVG Output")
184
  svg_output = gr.File(label="Download SVG")
185
-
186
- gr.Examples(
187
- examples=examples,
188
- fn=convert_to_vector,
189
- inputs=[image_input],
190
- outputs=[html, svg_output],
191
- cache_examples=False,
192
- run_on_click=True
193
- )
 
194
 
195
  colormode.change(handle_color_mode, inputs=colormode, outputs=output_text)
196
  hierarchical.change(handle_color_mode, inputs=hierarchical, outputs=output_text)
197
  mode.change(handle_color_mode, inputs=mode, outputs=output_text)
 
198
  default_values = {
199
  "color_precision": 6,
200
  "layer_difference": 16
@@ -255,11 +268,10 @@ with gr.Blocks(css=css) as gradio_app:
255
  outputs=[html, svg_output]
256
  )
257
 
258
- # Mount Gradio app to FastAPI at a subpath
259
- from gradio import mount_gradio_app
260
- app = mount_gradio_app(app, gradio_app, path="/gradio")
261
-
262
- # Health check endpoint
263
- @app.get("/")
264
- async def health_check():
265
- return {"status": "healthy"}
 
18
  # Configure CORS to allow requests from Figma plugin
19
  app.add_middleware(
20
  CORSMiddleware,
21
+ allow_origins=["https://www.figma.com", "*"], # Allow Figma and local testing
22
  allow_credentials=True,
23
  allow_methods=["*"],
24
  allow_headers=["*"],
25
  )
26
 
27
+ # VTracer conversion function
28
  def convert_to_vector(
29
  image,
30
  colormode="color",
 
45
  try:
46
  # Save the input image to a temporary file
47
  image.save(input_path)
48
+ logger.info(f"Saved image to {input_path}")
49
 
50
  # Convert the image to SVG using VTracer
51
  vtracer.convert_image_to_svg_py(
 
63
  splice_threshold=int(splice_threshold),
64
  path_precision=int(path_precision)
65
  )
66
+ logger.info(f"Converted image to SVG at {output_path}")
67
 
68
  # Read the SVG output
69
  with open(output_path, "r") as f:
70
  svg_content = f.read()
71
 
72
  return svg_content
73
+ except Exception as e:
74
+ logger.error(f"Error in convert_to_vector: {str(e)}")
75
+ raise HTTPException(status_code=500, detail=f"Conversion failed: {str(e)}")
76
  finally:
77
  # Clean up temporary files
78
  for path in [input_path, output_path]:
79
  if os.path.exists(path):
80
  try:
81
  os.remove(path)
82
+ logger.info(f"Removed {path}")
83
  except Exception as e:
84
  logger.warning(f"Failed to remove {path}: {str(e)}")
85
 
 
100
  path_precision: int = Form(3)
101
  ):
102
  try:
103
+ logger.info("Received request to /convert")
104
  # Read the uploaded image
105
  image_data = await file.read()
106
  image = Image.open(io.BytesIO(image_data))
 
116
  layer_difference,
117
  corner_threshold,
118
  length_threshold,
119
+ max iterations=max_iterations,
120
+ splice_threshold=splice_threshold,
121
+ path_precision=path_precision
122
  )
123
 
124
  return JSONResponse(content={"svg": svg_content})
 
126
  logger.error(f"Error in convert_image: {str(e)}")
127
  return JSONResponse(content={"error": str(e)}, status_code=500)
128
 
129
+ # Health check endpoint
130
+ @app.get("/")
131
+ async def health_check():
132
+ logger.info("Health check requested")
133
+ return {"status": "healthy"}
134
+
135
  # Gradio interface
136
  def handle_color_mode(value):
137
  return value
138
 
139
+ # Check if examples directory exists, else use empty list
140
+ examples_dir = "examples"
141
  examples = [
142
+ os.path.join(examples_dir, f) for f in ["11.jpg", "02.jpg", "03.jpg"]
143
+ if os.path.exists(os.path.join(examples_dir, f))
 
144
  ]
145
 
146
  css = """
 
173
  image_input = gr.Image(type="pil", label="Upload Image")
174
  with gr.Accordion("Advanced Settings", open=False):
175
  with gr.Accordion("Clustering", open=False):
176
+ colormode = gr.Radio([("COLOR", "color"), ("B/W", "binary")], value="color", label="Color Mode", show_label=False)
177
  filter_speckle = gr.Slider(0, 128, value=4, step=1, label="Filter Speckle", info="Cleaner")
178
  color_precision = gr.Slider(1, 8, value=6, step=1, label="Color Precision", info="More accurate")
179
  layer_difference = gr.Slider(0, 128, value=16, step=1, label="Gradient Step", info="Less layers")
180
+ hierarchical = gr.Radio([("STACKED", "stacked"), ("CUTOUT", "cutout")], value="stacked", label="Hierarchical Mode", show_label=False)
181
  with gr.Accordion("Curve Fitting", open=False):
182
+ mode = gr.Radio([("SPLINE", "spline"), ("POLYGON", "polygon"), ("PIXEL", "none")], value="spline", label="Mode", show_label=False)
183
  corner_threshold = gr.Slider(0, 180, value=60, step=1, label="Corner Threshold", info="Smoother")
184
+ length_threshold = gr.Slider(3.5, 10, value=4.0, step=0.1, label="Segment Length", info="More coarse")
185
  splice_threshold = gr.Slider(0, 180, value=45, step=1, label="Splice Threshold", info="Less accurate")
186
  max_iterations = gr.Slider(1, 20, value=10, step=1, label="Max Iterations", visible=False)
187
  path_precision = gr.Slider(1, 10, value=3, step=1, label="Path Precision", visible=False)
188
  output_text = gr.Textbox(label="Selected Mode", visible=False)
189
  with gr.Row():
190
  clear_button = gr.Button("Clear")
191
+ convert_button = gr.Button("✨ Convert to SVG", variant="primary", elem_classes=["generate-btn"])
192
 
193
  with gr.Column():
194
  html = gr.HTML(label="SVG Output")
195
  svg_output = gr.File(label="Download SVG")
196
+
197
+ if examples:
198
+ gr.Examples(
199
+ examples=examples,
200
+ fn=convert_to_vector,
201
+ inputs=[image_input],
202
+ outputs=[html, svg_output],
203
+ cache_examples=False,
204
+ run_on_click=True
205
+ )
206
 
207
  colormode.change(handle_color_mode, inputs=colormode, outputs=output_text)
208
  hierarchical.change(handle_color_mode, inputs=hierarchical, outputs=output_text)
209
  mode.change(handle_color_mode, inputs=mode, outputs=output_text)
210
+
211
  default_values = {
212
  "color_precision": 6,
213
  "layer_difference": 16
 
268
  outputs=[html, svg_output]
269
  )
270
 
271
+ # Mount Gradio app to FastAPI at /gradio
272
+ try:
273
+ from gradio import mount_gradio_app
274
+ app = mount_gradio_app(app, gradio_app, path="/gradio")
275
+ logger.info("Gradio app mounted successfully at /gradio")
276
+ except Exception as e:
277
+ logger.error(f"Failed to mount Gradio app: {str(e)}")