potato commited on
Commit
b3c12b0
ยท
1 Parent(s): e548cc9

fix: requirements.txt, app.py

Browse files
Files changed (2) hide show
  1. app.py +75 -52
  2. requirements.txt +3 -0
app.py CHANGED
@@ -4,32 +4,35 @@ import vtracer
4
  import tempfile
5
  import cairosvg
6
  import re
 
7
  from PIL import Image
8
  from datetime import datetime
9
-
10
- from flask import Flask, request, jsonify, send_from_directory
11
- from flask_cors import CORS
 
 
 
12
 
13
  from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
14
-
15
  import torchvision.transforms as transforms
16
  from model import Generator
17
 
 
 
 
 
18
  def setup_directories():
 
19
  os.makedirs(SVG_DIR, exist_ok=True)
20
  os.makedirs(THUMBNAIL_DIR, exist_ok=True)
21
  print(f"Directories '{SVG_DIR}' and '{THUMBNAIL_DIR}' are ready.")
22
 
23
- def sanitize_filename(prompt):
24
  """Removes characters that are invalid for filenames."""
25
-
26
  s = re.sub(r'[\\/*?:"<>|]', "", prompt)
 
27
 
28
- return s[:100]
29
-
30
- SVG_DIR = os.path.join(os.getcwd(), 'generated_svgs')
31
- THUMBNAIL_DIR = os.path.join(os.getcwd(), 'thumbnails')
32
- SKETCH_MODEL_WEIGHTS = 'checkpoints/netG_A_latest.pth'
33
 
34
  class ImageToSvgPipeline:
35
  """
@@ -46,7 +49,7 @@ class ImageToSvgPipeline:
46
  def _initialize_rinna_model(self):
47
  print("Loading Rinna Stable Diffusion model...")
48
  model_id = "rinna/japanese-stable-diffusion"
49
-
50
  self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
51
  model_id,
52
  torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
@@ -101,17 +104,16 @@ class ImageToSvgPipeline:
101
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
102
  image.save(tmp_file.name)
103
  tmp_path = tmp_file.name
104
-
 
105
  try:
106
- svg_output_path = tmp_path.replace(".png", ".svg")
107
  vtracer.convert_image_to_svg_py(tmp_path, svg_output_path)
108
-
109
  with open(svg_output_path, 'r', encoding='utf-8') as f:
110
  svg_data = f.read()
111
  finally:
112
  if os.path.exists(tmp_path): os.remove(tmp_path)
113
- if 'svg_output_path' in locals() and os.path.exists(svg_output_path): os.remove(svg_output_path)
114
-
115
  print("SVG extraction complete.")
116
  return svg_data
117
 
@@ -121,28 +123,37 @@ class ImageToSvgPipeline:
121
  svg_content = self._extract_svg(sketch_image)
122
  return svg_content
123
 
124
- app = Flask(__name__)
 
125
 
126
- CORS(app, resources={r"/*": {"origins": "*"}})
127
 
128
- pipeline = ImageToSvgPipeline(sketch_model_path=SKETCH_MODEL_WEIGHTS)
 
 
 
 
 
 
129
 
130
- def sanitize_filename(text):
131
- text = re.sub(r'[\\/*?:"<>|]', "", text)
132
- return text.strip()
133
 
134
- @app.route('/generate', methods=['POST'])
135
- def generate_svg():
136
- data = request.json
137
- prompt = data.get('prompt')
138
- if not prompt: return jsonify({"error": "Prompt is required"}), 400
 
 
139
 
140
  negative_prompt = "ไฝŽๅ“่ณชใ€ๆœ€ๆ‚ชใฎๅ“่ณชใ€ไธ‹ๆ‰‹ใชๆ‰‹ใ€ๆŒ‡ใŒ6ๆœฌใ€ๆŒ‡ใŒ4ๆœฌใ€ๅฅ‡ๅฝขใ€้†œใ„ใ€ใผใ‚„ใ‘ใฆใ„ใ‚‹ใ€ใผใ‚„ใ‘ใŸใ€ใ‚ฆใ‚ฉใƒผใ‚ฟใƒผใƒžใƒผใ‚ฏใ€็ฝฒๅใ€ใƒ†ใ‚ญใ‚นใƒˆ"
141
  try:
142
- svg_result = pipeline.process(prompt, negative_prompt)
143
 
 
144
  timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
145
- safe_prompt = sanitize_filename(prompt)[:50]
146
  filename = f"{timestamp}_{safe_prompt}.svg"
147
 
148
  svg_path = os.path.join(SVG_DIR, filename)
@@ -152,17 +163,20 @@ def generate_svg():
152
  thumbnail_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
153
  cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)
154
 
155
- return svg_result, 200, {'Content-Type': 'image/svg+xml'}
 
 
156
  except Exception as e:
157
  print(f"An error occurred during generation: {e}")
158
- return jsonify({"error": str(e)}), 500
159
 
160
- @app.route('/gallery', methods=['GET'])
161
- def get_gallery():
162
- try:
163
- page = int(request.args.get('page', 1))
164
- limit = int(request.args.get('limit', 8))
165
 
 
 
 
 
 
 
166
  svg_files = sorted([f for f in os.listdir(SVG_DIR) if f.endswith('.svg')], reverse=True)
167
 
168
  start_index = (page - 1) * limit
@@ -180,31 +194,40 @@ def get_gallery():
180
  })
181
 
182
  has_more = end_index < len(svg_files)
183
- return jsonify({"drawings": drawings, "hasMore": has_more})
184
  except Exception as e:
185
  print(f"Error fetching gallery: {e}")
186
- return jsonify({"error": "Failed to fetch gallery"}), 500
187
 
188
- @app.route('/svgs/<path:filename>')
189
- def get_svg(filename):
190
- return send_from_directory(SVG_DIR, filename)
191
 
192
- @app.route('/thumbnails/<path:filename>')
193
- def get_thumbnail(filename):
194
- return send_from_directory(THUMBNAIL_DIR, filename)
195
-
196
- @app.route('/drawings/<path:filename>', methods=['DELETE'])
197
- def delete_drawing_file(filename):
198
  try:
 
 
 
 
199
  svg_path = os.path.join(SVG_DIR, filename)
200
  thumb_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
 
 
 
 
201
  if os.path.exists(svg_path): os.remove(svg_path)
202
  if os.path.exists(thumb_path): os.remove(thumb_path)
203
- return jsonify({"message": f"Successfully deleted {filename}"})
 
204
  except Exception as e:
205
  print(f"Error deleting file: {e}")
206
- return jsonify({"error": "Failed to delete file"}), 500
 
 
 
 
207
 
208
  if __name__ == '__main__':
209
- print("Starting Flask server...")
210
- app.run(host='0.0.0.0', port=5000)
 
4
  import tempfile
5
  import cairosvg
6
  import re
7
+ import uvicorn
8
  from PIL import Image
9
  from datetime import datetime
10
+ from fastapi import FastAPI, HTTPException, Request
11
+ from fastapi.responses import FileResponse, JSONResponse, Response
12
+ from fastapi.staticfiles import StaticFiles
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from pydantic import BaseModel
15
+ from typing import Optional
16
 
17
  from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
 
18
  import torchvision.transforms as transforms
19
  from model import Generator
20
 
21
+ SVG_DIR = os.path.join(os.getcwd(), 'generated_svgs')
22
+ THUMBNAIL_DIR = os.path.join(os.getcwd(), 'thumbnails')
23
+ SKETCH_MODEL_WEIGHTS = 'checkpoints/netG_A_latest.pth'
24
+
25
  def setup_directories():
26
+ """Creates necessary directories if they don't exist."""
27
  os.makedirs(SVG_DIR, exist_ok=True)
28
  os.makedirs(THUMBNAIL_DIR, exist_ok=True)
29
  print(f"Directories '{SVG_DIR}' and '{THUMBNAIL_DIR}' are ready.")
30
 
31
+ def sanitize_filename(prompt: str) -> str:
32
  """Removes characters that are invalid for filenames."""
 
33
  s = re.sub(r'[\\/*?:"<>|]', "", prompt)
34
+ return s.strip()[:100]
35
 
 
 
 
 
 
36
 
37
  class ImageToSvgPipeline:
38
  """
 
49
  def _initialize_rinna_model(self):
50
  print("Loading Rinna Stable Diffusion model...")
51
  model_id = "rinna/japanese-stable-diffusion"
52
+
53
  self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
54
  model_id,
55
  torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
 
104
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
105
  image.save(tmp_file.name)
106
  tmp_path = tmp_file.name
107
+
108
+ svg_output_path = tmp_path.replace(".png", ".svg")
109
  try:
 
110
  vtracer.convert_image_to_svg_py(tmp_path, svg_output_path)
 
111
  with open(svg_output_path, 'r', encoding='utf-8') as f:
112
  svg_data = f.read()
113
  finally:
114
  if os.path.exists(tmp_path): os.remove(tmp_path)
115
+ if os.path.exists(svg_output_path): os.remove(svg_output_path)
116
+
117
  print("SVG extraction complete.")
118
  return svg_data
119
 
 
123
  svg_content = self._extract_svg(sketch_image)
124
  return svg_content
125
 
126
+ setup_directories()
127
+ pipeline = ImageToSvgPipeline(sketch_model_path=SKETCH_MODEL_WEIGHTS)
128
 
129
+ app = FastAPI()
130
 
131
+ app.add_middleware(
132
+ CORSMiddleware,
133
+ allow_origins=["*"], # Allows all origins
134
+ allow_credentials=True,
135
+ allow_methods=["*"], # Allows all methods
136
+ allow_headers=["*"], # Allows all headers
137
+ )
138
 
139
+ class GenerateRequest(BaseModel):
140
+ prompt: str
 
141
 
142
+ @app.post("/generate")
143
+ async def generate_svg(item: GenerateRequest):
144
+ """
145
+ Receives a prompt, generates an SVG, saves it, and returns the SVG content.
146
+ """
147
+ if not item.prompt:
148
+ raise HTTPException(status_code=400, detail="Prompt is required")
149
 
150
  negative_prompt = "ไฝŽๅ“่ณชใ€ๆœ€ๆ‚ชใฎๅ“่ณชใ€ไธ‹ๆ‰‹ใชๆ‰‹ใ€ๆŒ‡ใŒ6ๆœฌใ€ๆŒ‡ใŒ4ๆœฌใ€ๅฅ‡ๅฝขใ€้†œใ„ใ€ใผใ‚„ใ‘ใฆใ„ใ‚‹ใ€ใผใ‚„ใ‘ใŸใ€ใ‚ฆใ‚ฉใƒผใ‚ฟใƒผใƒžใƒผใ‚ฏใ€็ฝฒๅใ€ใƒ†ใ‚ญใ‚นใƒˆ"
151
  try:
152
+ svg_result = pipeline.process(item.prompt, negative_prompt)
153
 
154
+ # Save the SVG and its thumbnail
155
  timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
156
+ safe_prompt = sanitize_filename(item.prompt)[:50]
157
  filename = f"{timestamp}_{safe_prompt}.svg"
158
 
159
  svg_path = os.path.join(SVG_DIR, filename)
 
163
  thumbnail_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
164
  cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)
165
 
166
+ # Return the SVG data directly in the response
167
+ return Response(content=svg_result, media_type="image/svg+xml")
168
+
169
  except Exception as e:
170
  print(f"An error occurred during generation: {e}")
171
+ raise HTTPException(status_code=500, detail=str(e))
172
 
 
 
 
 
 
173
 
174
+ @app.get("/gallery")
175
+ def get_gallery(page: int = 1, limit: int = 8):
176
+ """
177
+ Returns a paginated list of generated drawings.
178
+ """
179
+ try:
180
  svg_files = sorted([f for f in os.listdir(SVG_DIR) if f.endswith('.svg')], reverse=True)
181
 
182
  start_index = (page - 1) * limit
 
194
  })
195
 
196
  has_more = end_index < len(svg_files)
197
+ return {"drawings": drawings, "hasMore": has_more}
198
  except Exception as e:
199
  print(f"Error fetching gallery: {e}")
200
+ raise HTTPException(status_code=500, detail="Failed to fetch gallery")
201
 
 
 
 
202
 
203
+ @app.delete("/drawings/{filename}")
204
+ def delete_drawing_file(filename: str):
205
+ """
206
+ Deletes a specific SVG and its corresponding thumbnail.
207
+ """
 
208
  try:
209
+ # Sanitize filename to prevent directory traversal
210
+ if ".." in filename or "/" in filename:
211
+ raise HTTPException(status_code=400, detail="Invalid filename")
212
+
213
  svg_path = os.path.join(SVG_DIR, filename)
214
  thumb_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
215
+
216
+ if not os.path.exists(svg_path):
217
+ raise HTTPException(status_code=404, detail="File not found")
218
+
219
  if os.path.exists(svg_path): os.remove(svg_path)
220
  if os.path.exists(thumb_path): os.remove(thumb_path)
221
+
222
+ return JSONResponse(content={"message": f"Successfully deleted {filename}"})
223
  except Exception as e:
224
  print(f"Error deleting file: {e}")
225
+ raise HTTPException(status_code=500, detail="Failed to delete file")
226
+
227
+ app.mount("/svgs", StaticFiles(directory=SVG_DIR), name="svgs")
228
+ app.mount("/thumbnails", StaticFiles(directory=THUMBNAIL_DIR), name="thumbnails")
229
+
230
 
231
  if __name__ == '__main__':
232
+ print("Starting FastAPI server...")
233
+ uvicorn.run(app, host='0.0.0.0', port=5000)
requirements.txt CHANGED
@@ -1,5 +1,7 @@
1
  flask
2
  Flask-Cors
 
 
3
  torch
4
  diffusers==0.35.1
5
  transformers==4.56.2
@@ -11,3 +13,4 @@ torchvision==0.23.0
11
  sentencepiece==0.2.0
12
  scipy
13
  numpy
 
 
1
  flask
2
  Flask-Cors
3
+ fastapi
4
+ uvicorn[standard]
5
  torch
6
  diffusers==0.35.1
7
  transformers==4.56.2
 
13
  sentencepiece==0.2.0
14
  scipy
15
  numpy
16
+ python-multipart