Spaces:
Runtime error
Runtime error
potato
commited on
Commit
ยท
b3c12b0
1
Parent(s):
e548cc9
fix: requirements.txt, app.py
Browse files- app.py +75 -52
- requirements.txt +3 -0
app.py
CHANGED
|
@@ -4,32 +4,35 @@ import vtracer
|
|
| 4 |
import tempfile
|
| 5 |
import cairosvg
|
| 6 |
import re
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
from datetime import datetime
|
| 9 |
-
|
| 10 |
-
from
|
| 11 |
-
from
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
|
| 14 |
-
|
| 15 |
import torchvision.transforms as transforms
|
| 16 |
from model import Generator
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def setup_directories():
|
|
|
|
| 19 |
os.makedirs(SVG_DIR, exist_ok=True)
|
| 20 |
os.makedirs(THUMBNAIL_DIR, exist_ok=True)
|
| 21 |
print(f"Directories '{SVG_DIR}' and '{THUMBNAIL_DIR}' are ready.")
|
| 22 |
|
| 23 |
-
def sanitize_filename(prompt):
|
| 24 |
"""Removes characters that are invalid for filenames."""
|
| 25 |
-
|
| 26 |
s = re.sub(r'[\\/*?:"<>|]', "", prompt)
|
|
|
|
| 27 |
|
| 28 |
-
return s[:100]
|
| 29 |
-
|
| 30 |
-
SVG_DIR = os.path.join(os.getcwd(), 'generated_svgs')
|
| 31 |
-
THUMBNAIL_DIR = os.path.join(os.getcwd(), 'thumbnails')
|
| 32 |
-
SKETCH_MODEL_WEIGHTS = 'checkpoints/netG_A_latest.pth'
|
| 33 |
|
| 34 |
class ImageToSvgPipeline:
|
| 35 |
"""
|
|
@@ -46,7 +49,7 @@ class ImageToSvgPipeline:
|
|
| 46 |
def _initialize_rinna_model(self):
|
| 47 |
print("Loading Rinna Stable Diffusion model...")
|
| 48 |
model_id = "rinna/japanese-stable-diffusion"
|
| 49 |
-
|
| 50 |
self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
|
| 51 |
model_id,
|
| 52 |
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
|
@@ -101,17 +104,16 @@ class ImageToSvgPipeline:
|
|
| 101 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
|
| 102 |
image.save(tmp_file.name)
|
| 103 |
tmp_path = tmp_file.name
|
| 104 |
-
|
|
|
|
| 105 |
try:
|
| 106 |
-
svg_output_path = tmp_path.replace(".png", ".svg")
|
| 107 |
vtracer.convert_image_to_svg_py(tmp_path, svg_output_path)
|
| 108 |
-
|
| 109 |
with open(svg_output_path, 'r', encoding='utf-8') as f:
|
| 110 |
svg_data = f.read()
|
| 111 |
finally:
|
| 112 |
if os.path.exists(tmp_path): os.remove(tmp_path)
|
| 113 |
-
if
|
| 114 |
-
|
| 115 |
print("SVG extraction complete.")
|
| 116 |
return svg_data
|
| 117 |
|
|
@@ -121,28 +123,37 @@ class ImageToSvgPipeline:
|
|
| 121 |
svg_content = self._extract_svg(sketch_image)
|
| 122 |
return svg_content
|
| 123 |
|
| 124 |
-
|
|
|
|
| 125 |
|
| 126 |
-
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
return text.strip()
|
| 133 |
|
| 134 |
-
@app.
|
| 135 |
-
def generate_svg():
|
| 136 |
-
|
| 137 |
-
prompt
|
| 138 |
-
|
|
|
|
|
|
|
| 139 |
|
| 140 |
negative_prompt = "ไฝๅ่ณชใๆๆชใฎๅ่ณชใไธๆใชๆใๆใ6ๆฌใๆใ4ๆฌใๅฅๅฝขใ้ใใใผใใใฆใใใใผใใใใใฆใฉใผใฟใผใใผใฏใ็ฝฒๅใใใญในใ"
|
| 141 |
try:
|
| 142 |
-
svg_result = pipeline.process(prompt, negative_prompt)
|
| 143 |
|
|
|
|
| 144 |
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
| 145 |
-
safe_prompt = sanitize_filename(prompt)[:50]
|
| 146 |
filename = f"{timestamp}_{safe_prompt}.svg"
|
| 147 |
|
| 148 |
svg_path = os.path.join(SVG_DIR, filename)
|
|
@@ -152,17 +163,20 @@ def generate_svg():
|
|
| 152 |
thumbnail_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
|
| 153 |
cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)
|
| 154 |
|
| 155 |
-
|
|
|
|
|
|
|
| 156 |
except Exception as e:
|
| 157 |
print(f"An error occurred during generation: {e}")
|
| 158 |
-
|
| 159 |
|
| 160 |
-
@app.route('/gallery', methods=['GET'])
|
| 161 |
-
def get_gallery():
|
| 162 |
-
try:
|
| 163 |
-
page = int(request.args.get('page', 1))
|
| 164 |
-
limit = int(request.args.get('limit', 8))
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
svg_files = sorted([f for f in os.listdir(SVG_DIR) if f.endswith('.svg')], reverse=True)
|
| 167 |
|
| 168 |
start_index = (page - 1) * limit
|
|
@@ -180,31 +194,40 @@ def get_gallery():
|
|
| 180 |
})
|
| 181 |
|
| 182 |
has_more = end_index < len(svg_files)
|
| 183 |
-
return
|
| 184 |
except Exception as e:
|
| 185 |
print(f"Error fetching gallery: {e}")
|
| 186 |
-
|
| 187 |
|
| 188 |
-
@app.route('/svgs/<path:filename>')
|
| 189 |
-
def get_svg(filename):
|
| 190 |
-
return send_from_directory(SVG_DIR, filename)
|
| 191 |
|
| 192 |
-
@app.
|
| 193 |
-
def
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def delete_drawing_file(filename):
|
| 198 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
svg_path = os.path.join(SVG_DIR, filename)
|
| 200 |
thumb_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
if os.path.exists(svg_path): os.remove(svg_path)
|
| 202 |
if os.path.exists(thumb_path): os.remove(thumb_path)
|
| 203 |
-
|
|
|
|
| 204 |
except Exception as e:
|
| 205 |
print(f"Error deleting file: {e}")
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
if __name__ == '__main__':
|
| 209 |
-
print("Starting
|
| 210 |
-
|
|
|
|
| 4 |
import tempfile
|
| 5 |
import cairosvg
|
| 6 |
import re
|
| 7 |
+
import uvicorn
|
| 8 |
from PIL import Image
|
| 9 |
from datetime import datetime
|
| 10 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 11 |
+
from fastapi.responses import FileResponse, JSONResponse, Response
|
| 12 |
+
from fastapi.staticfiles import StaticFiles
|
| 13 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
from typing import Optional
|
| 16 |
|
| 17 |
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
|
|
|
|
| 18 |
import torchvision.transforms as transforms
|
| 19 |
from model import Generator
|
| 20 |
|
| 21 |
+
SVG_DIR = os.path.join(os.getcwd(), 'generated_svgs')
|
| 22 |
+
THUMBNAIL_DIR = os.path.join(os.getcwd(), 'thumbnails')
|
| 23 |
+
SKETCH_MODEL_WEIGHTS = 'checkpoints/netG_A_latest.pth'
|
| 24 |
+
|
| 25 |
def setup_directories():
|
| 26 |
+
"""Creates necessary directories if they don't exist."""
|
| 27 |
os.makedirs(SVG_DIR, exist_ok=True)
|
| 28 |
os.makedirs(THUMBNAIL_DIR, exist_ok=True)
|
| 29 |
print(f"Directories '{SVG_DIR}' and '{THUMBNAIL_DIR}' are ready.")
|
| 30 |
|
| 31 |
+
def sanitize_filename(prompt: str) -> str:
|
| 32 |
"""Removes characters that are invalid for filenames."""
|
|
|
|
| 33 |
s = re.sub(r'[\\/*?:"<>|]', "", prompt)
|
| 34 |
+
return s.strip()[:100]
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
class ImageToSvgPipeline:
|
| 38 |
"""
|
|
|
|
| 49 |
def _initialize_rinna_model(self):
|
| 50 |
print("Loading Rinna Stable Diffusion model...")
|
| 51 |
model_id = "rinna/japanese-stable-diffusion"
|
| 52 |
+
|
| 53 |
self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
|
| 54 |
model_id,
|
| 55 |
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
|
|
|
| 104 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
|
| 105 |
image.save(tmp_file.name)
|
| 106 |
tmp_path = tmp_file.name
|
| 107 |
+
|
| 108 |
+
svg_output_path = tmp_path.replace(".png", ".svg")
|
| 109 |
try:
|
|
|
|
| 110 |
vtracer.convert_image_to_svg_py(tmp_path, svg_output_path)
|
|
|
|
| 111 |
with open(svg_output_path, 'r', encoding='utf-8') as f:
|
| 112 |
svg_data = f.read()
|
| 113 |
finally:
|
| 114 |
if os.path.exists(tmp_path): os.remove(tmp_path)
|
| 115 |
+
if os.path.exists(svg_output_path): os.remove(svg_output_path)
|
| 116 |
+
|
| 117 |
print("SVG extraction complete.")
|
| 118 |
return svg_data
|
| 119 |
|
|
|
|
| 123 |
svg_content = self._extract_svg(sketch_image)
|
| 124 |
return svg_content
|
| 125 |
|
| 126 |
+
setup_directories()
|
| 127 |
+
pipeline = ImageToSvgPipeline(sketch_model_path=SKETCH_MODEL_WEIGHTS)
|
| 128 |
|
| 129 |
+
app = FastAPI()
|
| 130 |
|
| 131 |
+
app.add_middleware(
|
| 132 |
+
CORSMiddleware,
|
| 133 |
+
allow_origins=["*"], # Allows all origins
|
| 134 |
+
allow_credentials=True,
|
| 135 |
+
allow_methods=["*"], # Allows all methods
|
| 136 |
+
allow_headers=["*"], # Allows all headers
|
| 137 |
+
)
|
| 138 |
|
| 139 |
+
class GenerateRequest(BaseModel):
|
| 140 |
+
prompt: str
|
|
|
|
| 141 |
|
| 142 |
+
@app.post("/generate")
|
| 143 |
+
async def generate_svg(item: GenerateRequest):
|
| 144 |
+
"""
|
| 145 |
+
Receives a prompt, generates an SVG, saves it, and returns the SVG content.
|
| 146 |
+
"""
|
| 147 |
+
if not item.prompt:
|
| 148 |
+
raise HTTPException(status_code=400, detail="Prompt is required")
|
| 149 |
|
| 150 |
negative_prompt = "ไฝๅ่ณชใๆๆชใฎๅ่ณชใไธๆใชๆใๆใ6ๆฌใๆใ4ๆฌใๅฅๅฝขใ้ใใใผใใใฆใใใใผใใใใใฆใฉใผใฟใผใใผใฏใ็ฝฒๅใใใญในใ"
|
| 151 |
try:
|
| 152 |
+
svg_result = pipeline.process(item.prompt, negative_prompt)
|
| 153 |
|
| 154 |
+
# Save the SVG and its thumbnail
|
| 155 |
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
| 156 |
+
safe_prompt = sanitize_filename(item.prompt)[:50]
|
| 157 |
filename = f"{timestamp}_{safe_prompt}.svg"
|
| 158 |
|
| 159 |
svg_path = os.path.join(SVG_DIR, filename)
|
|
|
|
| 163 |
thumbnail_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
|
| 164 |
cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)
|
| 165 |
|
| 166 |
+
# Return the SVG data directly in the response
|
| 167 |
+
return Response(content=svg_result, media_type="image/svg+xml")
|
| 168 |
+
|
| 169 |
except Exception as e:
|
| 170 |
print(f"An error occurred during generation: {e}")
|
| 171 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
+
@app.get("/gallery")
|
| 175 |
+
def get_gallery(page: int = 1, limit: int = 8):
|
| 176 |
+
"""
|
| 177 |
+
Returns a paginated list of generated drawings.
|
| 178 |
+
"""
|
| 179 |
+
try:
|
| 180 |
svg_files = sorted([f for f in os.listdir(SVG_DIR) if f.endswith('.svg')], reverse=True)
|
| 181 |
|
| 182 |
start_index = (page - 1) * limit
|
|
|
|
| 194 |
})
|
| 195 |
|
| 196 |
has_more = end_index < len(svg_files)
|
| 197 |
+
return {"drawings": drawings, "hasMore": has_more}
|
| 198 |
except Exception as e:
|
| 199 |
print(f"Error fetching gallery: {e}")
|
| 200 |
+
raise HTTPException(status_code=500, detail="Failed to fetch gallery")
|
| 201 |
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
+
@app.delete("/drawings/{filename}")
|
| 204 |
+
def delete_drawing_file(filename: str):
|
| 205 |
+
"""
|
| 206 |
+
Deletes a specific SVG and its corresponding thumbnail.
|
| 207 |
+
"""
|
|
|
|
| 208 |
try:
|
| 209 |
+
# Sanitize filename to prevent directory traversal
|
| 210 |
+
if ".." in filename or "/" in filename:
|
| 211 |
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
| 212 |
+
|
| 213 |
svg_path = os.path.join(SVG_DIR, filename)
|
| 214 |
thumb_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
|
| 215 |
+
|
| 216 |
+
if not os.path.exists(svg_path):
|
| 217 |
+
raise HTTPException(status_code=404, detail="File not found")
|
| 218 |
+
|
| 219 |
if os.path.exists(svg_path): os.remove(svg_path)
|
| 220 |
if os.path.exists(thumb_path): os.remove(thumb_path)
|
| 221 |
+
|
| 222 |
+
return JSONResponse(content={"message": f"Successfully deleted {filename}"})
|
| 223 |
except Exception as e:
|
| 224 |
print(f"Error deleting file: {e}")
|
| 225 |
+
raise HTTPException(status_code=500, detail="Failed to delete file")
|
| 226 |
+
|
| 227 |
+
app.mount("/svgs", StaticFiles(directory=SVG_DIR), name="svgs")
|
| 228 |
+
app.mount("/thumbnails", StaticFiles(directory=THUMBNAIL_DIR), name="thumbnails")
|
| 229 |
+
|
| 230 |
|
| 231 |
if __name__ == '__main__':
|
| 232 |
+
print("Starting FastAPI server...")
|
| 233 |
+
uvicorn.run(app, host='0.0.0.0', port=5000)
|
requirements.txt
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
flask
|
| 2 |
Flask-Cors
|
|
|
|
|
|
|
| 3 |
torch
|
| 4 |
diffusers==0.35.1
|
| 5 |
transformers==4.56.2
|
|
@@ -11,3 +13,4 @@ torchvision==0.23.0
|
|
| 11 |
sentencepiece==0.2.0
|
| 12 |
scipy
|
| 13 |
numpy
|
|
|
|
|
|
| 1 |
flask
|
| 2 |
Flask-Cors
|
| 3 |
+
fastapi
|
| 4 |
+
uvicorn[standard]
|
| 5 |
torch
|
| 6 |
diffusers==0.35.1
|
| 7 |
transformers==4.56.2
|
|
|
|
| 13 |
sentencepiece==0.2.0
|
| 14 |
scipy
|
| 15 |
numpy
|
| 16 |
+
python-multipart
|