potato commited on
Commit
8da5801
ยท
1 Parent(s): 71bde1b

remove diffusion model's default pbar, add callback function

Browse files
Files changed (2) hide show
  1. app.py +215 -105
  2. requirements.txt +6 -0
app.py CHANGED
@@ -4,41 +4,41 @@ import vtracer
4
  import tempfile
5
  import cairosvg
6
  import re
7
- import uvicorn
8
  from PIL import Image
9
  from datetime import datetime
10
- from fastapi import FastAPI, HTTPException, Request
11
- from fastapi.responses import FileResponse, JSONResponse, Response
12
- from fastapi.staticfiles import StaticFiles
13
- from fastapi.middleware.cors import CORSMiddleware
14
- from pydantic import BaseModel
15
- from typing import Optional
 
 
16
 
17
  from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
 
 
18
  import torchvision.transforms as transforms
19
  from model import Generator
 
20
 
21
- SVG_DIR = os.path.join(os.getcwd(), 'generated_svgs')
22
- THUMBNAIL_DIR = os.path.join(os.getcwd(), 'thumbnails')
23
- SKETCH_MODEL_WEIGHTS = 'checkpoints/netG_A_latest.pth'
24
 
25
  def setup_directories():
26
- """Creates necessary directories if they don't exist."""
27
- os.makedirs(SVG_DIR, exist_ok=True)
28
  os.makedirs(THUMBNAIL_DIR, exist_ok=True)
29
- print(f"Directories '{SVG_DIR}' and '{THUMBNAIL_DIR}' are ready.")
30
 
31
- def sanitize_filename(prompt: str) -> str:
32
  """Removes characters that are invalid for filenames."""
33
  s = re.sub(r'[\\/*?:"<>|]', "", prompt)
34
- return s.strip()[:100]
35
 
 
 
 
36
 
37
  class ImageToSvgPipeline:
38
- """
39
- A class to handle the entire pipeline from text prompt to SVG.
40
- Initializes models once to be reused.
41
- """
42
  def __init__(self, sketch_model_path: str):
43
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
44
  print(f"Using device: {self.device}")
@@ -49,7 +49,7 @@ class ImageToSvgPipeline:
49
  def _initialize_rinna_model(self):
50
  print("Loading Rinna Stable Diffusion model...")
51
  model_id = "rinna/japanese-stable-diffusion"
52
-
53
  self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
54
  model_id,
55
  torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
@@ -60,8 +60,21 @@ class ImageToSvgPipeline:
60
  )
61
  self.rinna_pipe.tokenizer.model_max_length = 77
62
  self.rinna_pipe.to(self.device)
 
63
  print("Rinna model loaded.")
64
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def _initialize_sketch_model(self, model_path: str):
66
  print(f"Loading Sketch Generator model from {model_path}...")
67
  if not os.path.exists(model_path):
@@ -77,18 +90,19 @@ class ImageToSvgPipeline:
77
  ])
78
  print("Sketch model loaded.")
79
 
80
- def _generate_image(self, prompt: str, negative_prompt: str, steps: int = 8) -> Image.Image:
81
  print(f"Generating image for prompt: '{prompt}'")
82
  with torch.no_grad():
83
- image = self.rinna_pipe(
84
  prompt,
85
  negative_prompt=negative_prompt,
86
  num_inference_steps=steps,
87
  guidance_scale=7.5,
88
- width=512,
89
- height=512,
90
- ).images[0]
91
- return image
 
92
 
93
  def _convert_to_sketch(self, image: Image.Image) -> Image.Image:
94
  print("Converting image to sketch...")
@@ -104,127 +118,223 @@ class ImageToSvgPipeline:
104
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
105
  image.save(tmp_file.name)
106
  tmp_path = tmp_file.name
107
-
108
- svg_output_path = tmp_path.replace(".png", ".svg")
109
  try:
 
110
  vtracer.convert_image_to_svg_py(tmp_path, svg_output_path)
111
  with open(svg_output_path, 'r', encoding='utf-8') as f:
112
  svg_data = f.read()
113
  finally:
114
  if os.path.exists(tmp_path): os.remove(tmp_path)
115
- if os.path.exists(svg_output_path): os.remove(svg_output_path)
116
-
117
  print("SVG extraction complete.")
118
  return svg_data
119
 
120
- def process(self, prompt: str, negative_prompt: str) -> str:
121
- generated_image = self._generate_image(prompt, negative_prompt)
122
- sketch_image = self._convert_to_sketch(generated_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  svg_content = self._extract_svg(sketch_image)
124
- return svg_content
125
-
126
- setup_directories()
127
- pipeline = ImageToSvgPipeline(sketch_model_path=SKETCH_MODEL_WEIGHTS)
128
-
129
- app = FastAPI()
130
 
131
- app.add_middleware(
132
- CORSMiddleware,
133
- allow_origins=["*"], # Allows all origins
134
- allow_credentials=True,
135
- allow_methods=["*"], # Allows all methods
136
- allow_headers=["*"], # Allows all headers
137
- )
138
 
139
- class GenerateRequest(BaseModel):
140
- prompt: str
141
 
142
- @app.post("/generate")
143
- async def generate_svg(item: GenerateRequest):
144
- """
145
- Receives a prompt, generates an SVG, saves it, and returns the SVG content.
146
- """
147
- if not item.prompt:
148
- raise HTTPException(status_code=400, detail="Prompt is required")
149
-
150
- negative_prompt = "ไฝŽๅ“่ณชใ€ๆœ€ๆ‚ชใฎๅ“่ณชใ€ๅฅ‡ๅฝขใ€้†œใ„ใ€ใผใ‚„ใ‘ใฆใ„ใ‚‹ใ€ใผใ‚„ใ‘ใŸใ€ใ‚ฆใ‚ฉใƒผใ‚ฟใƒผใƒžใƒผใ‚ฏใ€็ฝฒๅใ€ใƒ†ใ‚ญใ‚นใƒˆใ€ใƒ•ใƒฌใƒผใƒ ใ‹ใ‚‰ๅค–ใ‚ŒใŸใ€ๆ‰‹่ถณใŒๅˆ‡ใ‚Œใฆใ„ใ‚‹ใ€ใ‚ฏใƒญใƒƒใƒ—ใ•ใ‚ŒใŸใ€่ขซๅ†™ไฝ“ใŒๅˆ‡ใ‚Šๅ–ใ‚‰ใ‚Œใฆใ„ใ‚‹ใ€ๆง‹ๆˆใŒๆ‚ชใ„ใ€็„ฆ็‚นใŒๅˆใฃใฆใ„ใชใ„"
151
- try:
152
- svg_result = pipeline.process(item.prompt, negative_prompt)
153
 
154
- # Save the SVG and its thumbnail
155
- timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
156
- safe_prompt = sanitize_filename(item.prompt)[:50]
157
- filename = f"{timestamp}_{safe_prompt}.svg"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- svg_path = os.path.join(SVG_DIR, filename)
160
- with open(svg_path, 'w', encoding='utf-8') as f:
161
- f.write(svg_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- thumbnail_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
164
- cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)
165
 
166
- # Return the SVG data directly in the response
167
- return Response(content=svg_result, media_type="image/svg+xml")
 
 
 
 
 
168
 
169
- except Exception as e:
170
- print(f"An error occurred during generation: {e}")
171
- raise HTTPException(status_code=500, detail=str(e))
172
 
173
 
174
- @app.get("/gallery")
175
- def get_gallery(page: int = 1, limit: int = 8):
176
- """
177
- Returns a paginated list of generated drawings.
178
- """
179
  try:
180
- svg_files = sorted([f for f in os.listdir(SVG_DIR) if f.endswith('.svg')], reverse=True)
 
181
 
 
182
  start_index = (page - 1) * limit
183
  end_index = start_index + limit
184
- paginated_files = svg_files[start_index:end_index]
185
 
186
  drawings = []
187
  for filename in paginated_files:
188
- prompt_match = re.match(r"\d+_(.+)\.svg", filename)
189
  prompt = prompt_match.group(1).replace('_', ' ') if prompt_match else "Prompt not found"
190
  drawings.append({
191
  "filename": filename,
192
- "thumbnail": f"/thumbnails/{filename.replace('.svg', '.png')}",
193
  "prompt": prompt
194
  })
195
 
196
- has_more = end_index < len(svg_files)
197
- return {"drawings": drawings, "hasMore": has_more}
198
  except Exception as e:
199
  print(f"Error fetching gallery: {e}")
200
- raise HTTPException(status_code=500, detail="Failed to fetch gallery")
201
-
202
-
203
- @app.delete("/drawings/{filename}")
204
- def delete_drawing_file(filename: str):
205
- """
206
- Deletes a specific SVG and its corresponding thumbnail.
207
- """
208
- try:
209
- # Sanitize filename to prevent directory traversal
210
- if ".." in filename or "/" in filename:
211
- raise HTTPException(status_code=400, detail="Invalid filename")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- svg_path = os.path.join(SVG_DIR, filename)
214
- thumb_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
 
215
 
216
- if not os.path.exists(svg_path):
217
- raise HTTPException(status_code=404, detail="File not found")
 
218
 
219
- if os.path.exists(svg_path): os.remove(svg_path)
 
 
 
 
 
220
  if os.path.exists(thumb_path): os.remove(thumb_path)
221
-
222
- return JSONResponse(content={"message": f"Successfully deleted {filename}"})
223
  except Exception as e:
224
  print(f"Error deleting file: {e}")
225
- raise HTTPException(status_code=500, detail="Failed to delete file")
226
 
227
- app.mount("/svgs", StaticFiles(directory=SVG_DIR), name="svgs")
228
  app.mount("/thumbnails", StaticFiles(directory=THUMBNAIL_DIR), name="thumbnails")
229
 
230
 
 
4
  import tempfile
5
  import cairosvg
6
  import re
 
7
  from PIL import Image
8
  from datetime import datetime
9
+ import gc
10
+ import json
11
+ import time
12
+ import queue
13
+ import threading
14
+
15
+ from flask import Flask, request, jsonify, send_from_directory, Response, stream_with_context
16
+ from flask_cors import CORS
17
 
18
  from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
19
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
20
+
21
  import torchvision.transforms as transforms
22
  from model import Generator
23
+ from utils import process_svg
24
 
25
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 
26
 
27
  def setup_directories():
28
+ os.makedirs(STROKES_DIR, exist_ok=True)
 
29
  os.makedirs(THUMBNAIL_DIR, exist_ok=True)
30
+ print(f"Directories '{STROKES_DIR}' and '{THUMBNAIL_DIR}' are ready.")
31
 
32
+ def sanitize_filename(prompt):
33
  """Removes characters that are invalid for filenames."""
34
  s = re.sub(r'[\\/*?:"<>|]', "", prompt)
35
+ return s[:100]
36
 
37
+ STROKES_DIR = os.path.join(os.getcwd(), 'strokes')
38
+ THUMBNAIL_DIR = os.path.join(os.getcwd(), 'thumbnails')
39
+ SKETCH_MODEL_WEIGHTS = os.path.join('checkpoints', 'netG_A_latest.pth')
40
 
41
  class ImageToSvgPipeline:
 
 
 
 
42
  def __init__(self, sketch_model_path: str):
43
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
44
  print(f"Using device: {self.device}")
 
49
  def _initialize_rinna_model(self):
50
  print("Loading Rinna Stable Diffusion model...")
51
  model_id = "rinna/japanese-stable-diffusion"
52
+
53
  self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
54
  model_id,
55
  torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
 
60
  )
61
  self.rinna_pipe.tokenizer.model_max_length = 77
62
  self.rinna_pipe.to(self.device)
63
+ self.rinna_pipe.set_progress_bar_config(disable=True)
64
  print("Rinna model loaded.")
65
 
66
+ def unload_rinna_model(self):
67
+ if hasattr(self, 'rinna_pipe'):
68
+ print("Unloading Rinna Stable Diffusion model...")
69
+ del self.rinna_pipe
70
+ gc.collect()
71
+ if self.device == "cuda":
72
+ torch.cuda.empty_cache()
73
+ print("GPU memory cache cleared.")
74
+ print("Rinna model unloaded successfully.")
75
+ else:
76
+ print("Rinna model is not currently loaded.")
77
+
78
  def _initialize_sketch_model(self, model_path: str):
79
  print(f"Loading Sketch Generator model from {model_path}...")
80
  if not os.path.exists(model_path):
 
90
  ])
91
  print("Sketch model loaded.")
92
 
93
+ def _generate_image(self, prompt: str, negative_prompt: str, steps: int = 30, callback=None) -> Image.Image:
94
  print(f"Generating image for prompt: '{prompt}'")
95
  with torch.no_grad():
96
+ output: StableDiffusionPipelineOutput = self.rinna_pipe(
97
  prompt,
98
  negative_prompt=negative_prompt,
99
  num_inference_steps=steps,
100
  guidance_scale=7.5,
101
+ width=720,
102
+ height=720,
103
+ callback_on_step_end=callback
104
+ )
105
+ return output.images[0]
106
 
107
  def _convert_to_sketch(self, image: Image.Image) -> Image.Image:
108
  print("Converting image to sketch...")
 
118
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
119
  image.save(tmp_file.name)
120
  tmp_path = tmp_file.name
121
+
 
122
  try:
123
+ svg_output_path = tmp_path.replace(".png", ".svg")
124
  vtracer.convert_image_to_svg_py(tmp_path, svg_output_path)
125
  with open(svg_output_path, 'r', encoding='utf-8') as f:
126
  svg_data = f.read()
127
  finally:
128
  if os.path.exists(tmp_path): os.remove(tmp_path)
129
+ if 'svg_output_path' in locals() and os.path.exists(svg_output_path): os.remove(svg_output_path)
130
+
131
  print("SVG extraction complete.")
132
  return svg_data
133
 
134
+ def process(self, prompt: str, img_path: str, negative_prompt: str, callback=None):
135
+ """Processes the image generation and conversion, with progress callbacks."""
136
+ def _callback(progress, step_name):
137
+ if callback:
138
+ callback(progress, step_name)
139
+
140
+ generated_img = None
141
+ if img_path is None:
142
+ total_diffusion_steps = 30
143
+
144
+ def diffusion_callback(pipe, step_index, timestep, callback_kwargs):
145
+ progress = int(5 + ((step_index + 1) / total_diffusion_steps) * 75)
146
+ _callback(progress, "Generating image...")
147
+ return callback_kwargs
148
+
149
+ _callback(5, "Starting image generation...")
150
+ generated_img = self._generate_image(
151
+ prompt,
152
+ negative_prompt,
153
+ steps=total_diffusion_steps,
154
+ callback=diffusion_callback
155
+ )
156
+ gc.collect()
157
+ torch.cuda.empty_cache()
158
+ _callback(80, "Base image generated.")
159
+ img_to_process = generated_img
160
+ else:
161
+ generated_img = Image.open(img_path)
162
+ img_to_process = generated_img
163
+ _callback(80, "Image loaded.")
164
+
165
+ _callback(85, "Converting to sketch...")
166
+ sketch_image = self._convert_to_sketch(img_to_process)
167
+
168
+ _callback(90, "Vectorizing sketch...")
169
  svg_content = self._extract_svg(sketch_image)
170
+ _callback(95, "SVG extracted.")
 
 
 
 
 
171
 
172
+ return svg_content, generated_img
 
 
 
 
 
 
173
 
 
 
174
 
175
+ app = Flask(__name__)
176
+ CORS(app, resources={r"/*": {"origins": "*"}})
177
+ pipeline = ImageToSvgPipeline(sketch_model_path=SKETCH_MODEL_WEIGHTS)
 
 
 
 
 
 
 
 
178
 
179
+ @app.after_request
180
+ def add_ngrok_header(response):
181
+ response.headers['ngrok-skip-browser-warning'] = 'true'
182
+ return response
183
+
184
+ @app.route('/generate', methods=['GET'])
185
+ def generate_stroke():
186
+ prompt = request.args.get('prompt')
187
+ if not prompt:
188
+ return jsonify({"error": "Prompt is required"}), 400
189
+
190
+ negative_prompt = (
191
+ "ไฝŽๅ“่ณชใ€ๆœ€ๆ‚ชใฎๅ“่ณชใ€ๅฅ‡ๅฝขใ€้†œใ„ใ€ใผใ‚„ใ‘ใฆใ„ใ‚‹ใ€ใผใ‚„ใ‘ใŸใ€"
192
+ "ใ‚ฆใ‚ฉใƒผใ‚ฟใƒผใƒžใƒผใ‚ฏใ€็ฝฒๅใ€ใƒ†ใ‚ญใ‚นใƒˆใ€ใƒ•ใƒฌใƒผใƒ ใ‹ใ‚‰ๅค–ใ‚ŒใŸใ€"
193
+ "ๆ‰‹่ถณใŒๅˆ‡ใ‚Œใฆใ„ใ‚‹ใ€ใ‚ฏใƒญใƒƒใƒ—ใ•ใ‚ŒใŸใ€่ขซๅ†™ไฝ“ใŒๅˆ‡ใ‚Šๅ–ใ‚‰ใ‚Œใฆใ„ใ‚‹ใ€"
194
+ "ๆง‹ๆˆใŒๆ‚ชใ„ใ€็„ฆ็‚นใŒๅˆใฃใฆใ„ใชใ„"
195
+ )
196
+
197
+ q = queue.Queue()
198
+
199
+ def worker():
200
+ """Runs the long-running task in a separate thread and puts progress into the queue."""
201
+ start_time = time.time()
202
+
203
+ def progress_callback(progress, step):
204
+ print(f"Progress: {progress}% - {step}")
205
+ data = json.dumps({"progress": progress, "step": step})
206
+ q.put(data)
207
 
208
+ try:
209
+ progress_callback(5, "Initializing...")
210
+
211
+ svg_result, generated_image = pipeline.process(prompt, None, negative_prompt, callback=progress_callback)
212
+
213
+ progress_callback(98, "Finalizing and saving...")
214
+
215
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
216
+ safe_prompt = sanitize_filename(prompt)[:60]
217
+ filename_base = f"{timestamp}_{safe_prompt}"
218
+
219
+ stroke_path = os.path.join(STROKES_DIR, f"{filename_base}.json")
220
+ stroke = process_svg(svg_result, "file")
221
+ with open(stroke_path, 'w', encoding='utf-8') as f:
222
+ json.dump(stroke, f, ensure_ascii=False, indent=2)
223
+
224
+ if generated_image:
225
+ thumbnail_path = os.path.join(THUMBNAIL_DIR, f"{filename_base}.png")
226
+ cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)
227
+
228
+ final_data = json.dumps({"progress": 100, "result": stroke, "step": "Complete!"})
229
+ q.put(final_data)
230
+ end_time = time.time()
231
+ print(f"Total generation time: {end_time - start_time:.2f} seconds")
232
+
233
+ except Exception as e:
234
+ print(f"Error during generation stream: {e}")
235
+ error_data = json.dumps({"error": str(e), "progress": 100})
236
+ q.put(error_data)
237
+ finally:
238
+ q.put(None)
239
 
240
+ threading.Thread(target=worker).start()
 
241
 
242
+ def generate():
243
+ """This generator reads from the queue and yields data to the client."""
244
+ while True:
245
+ item = q.get()
246
+ if item is None:
247
+ break
248
+ yield f"data: {item}\n\n"
249
 
250
+ return Response(stream_with_context(generate()), mimetype='text/event-stream')
 
 
251
 
252
 
253
+ @app.route('/gallery', methods=['GET'])
254
+ def get_gallery():
 
 
 
255
  try:
256
+ page = int(request.args.get('page', 1))
257
+ limit = int(request.args.get('limit', 8))
258
 
259
+ strokes_files = sorted([f for f in os.listdir(STROKES_DIR) if f.endswith('.json')], reverse=True)
260
  start_index = (page - 1) * limit
261
  end_index = start_index + limit
262
+ paginated_files = strokes_files[start_index:end_index]
263
 
264
  drawings = []
265
  for filename in paginated_files:
266
+ prompt_match = re.match(r"\d+_(.+)\.json", filename)
267
  prompt = prompt_match.group(1).replace('_', ' ') if prompt_match else "Prompt not found"
268
  drawings.append({
269
  "filename": filename,
270
+ "thumbnail": f"/thumbnails/{filename.replace('.json', '.png')}",
271
  "prompt": prompt
272
  })
273
 
274
+ has_more = end_index < len(strokes_files)
275
+ return jsonify({"drawings": drawings, "hasMore": has_more})
276
  except Exception as e:
277
  print(f"Error fetching gallery: {e}")
278
+ return jsonify({"error": "Failed to fetch gallery"}), 500
279
+
280
+ @app.route('/add_svg', methods=['POST'])
281
+ def add_svg():
282
+ data = request.json
283
+ folder_path = data.get('folderPath').strip()
284
+ count = 0
285
+ for file in os.listdir(folder_path):
286
+ file_path = os.path.join(folder_path, file)
287
+ stroke_path = os.path.join(STROKES_DIR, file.replace('.svg', '.json'))
288
+ stroke = process_svg(file_path, "path")
289
+ with open(stroke_path, 'w', encoding='utf-8') as f:
290
+ json.dump(stroke, f, ensure_ascii=False, indent=2)
291
+ thumbnail_path = os.path.join(THUMBNAIL_DIR, file.replace('.svg', '.png'))
292
+ cairosvg.svg2png(url=file_path, write_to=thumbnail_path, output_width=256, output_height=256)
293
+ count += 1
294
+ return jsonify({"status": "success", "message": f"Processed {count} SVG files."})
295
+
296
+ @app.route('/add_img', methods=['POST'])
297
+ def add_img():
298
+ data = request.json
299
+ folder_path = data.get('folderPath').strip()
300
+ count = 0
301
+ pipeline.unload_rinna_model()
302
+ for file in os.listdir(folder_path):
303
+ file_path = os.path.join(folder_path, file)
304
+ svg_result, _ = pipeline.process(None, file_path, None)
305
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
306
+ filename = f"{timestamp}_{file.replace('.jpg', '.json').replace('.png', '.json')}"
307
+ stroke_path = os.path.join(STROKES_DIR, filename)
308
+ stroke = process_svg(svg_result, "file")
309
+ with open(stroke_path, 'w', encoding='utf-8') as f:
310
+ json.dump(stroke, f, ensure_ascii=False, indent=2)
311
+ thumbnail_path = os.path.join(THUMBNAIL_DIR, filename.replace('.json', '.png'))
312
+ cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)
313
+ count += 1
314
+ pipeline._initialize_rinna_model()
315
+ return jsonify({"status": "success", "message": f"Processed {count} image files."})
316
 
317
+ @app.route('/strokes/<path:filename>')
318
+ def get_strokes(filename):
319
+ return send_from_directory(STROKES_DIR, filename)
320
 
321
+ @app.route('/thumbnails/<path:filename>')
322
+ def get_thumbnail(filename):
323
+ return send_from_directory(THUMBNAIL_DIR, filename)
324
 
325
+ @app.route('/drawings/<path:filename>', methods=['DELETE'])
326
+ def delete_drawing_file(filename):
327
+ try:
328
+ json_path = os.path.join(STROKES_DIR, filename)
329
+ thumb_path = os.path.join(THUMBNAIL_DIR, filename.replace('.json', '.png'))
330
+ if os.path.exists(json_path): os.remove(json_path)
331
  if os.path.exists(thumb_path): os.remove(thumb_path)
332
+ return jsonify({"message": f"Successfully deleted {filename}"})
 
333
  except Exception as e:
334
  print(f"Error deleting file: {e}")
335
+ return jsonify({"error": "Failed to delete file"}), 500
336
 
337
+ app.mount("/strokes", StaticFiles(directory=STROKES_DIR), name="strokes")
338
  app.mount("/thumbnails", StaticFiles(directory=THUMBNAIL_DIR), name="thumbnails")
339
 
340
 
requirements.txt CHANGED
@@ -14,3 +14,9 @@ sentencepiece==0.2.0
14
  scipy
15
  numpy
16
  python-multipart
 
 
 
 
 
 
 
14
  scipy
15
  numpy
16
  python-multipart
17
+ opencv-python
18
+ fast_tsp
19
+ python_tsp
20
+ lxml
21
+ svgpathtools
22
+ "huggingface_hub[cli]"