import os
import re
import tempfile
from datetime import datetime

import cairosvg
import torch
import torchvision.transforms as transforms
import vtracer
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
from PIL import Image

from model import Generator

# Output locations, resolved against the working directory at import time.
SVG_DIR = os.path.join(os.getcwd(), 'generated_svgs')
THUMBNAIL_DIR = os.path.join(os.getcwd(), 'thumbnails')
SKETCH_MODEL_WEIGHTS = 'checkpoints/netG_A_latest.pth'

# Characters removed from prompts before they are used in a filename: the
# reserved set \ / * ? : " < > | plus ASCII control characters (a NUL byte
# would make open() raise).
_INVALID_FILENAME_CHARS = re.compile(r'[\\/*?:"<>|\x00-\x1f]')

# Stored drawings are named "<digits>_<prompt>.svg"; group 1 is the prompt.
_GALLERY_NAME_RE = re.compile(r"\d+_(.+)\.svg")


def setup_directories():
    """Create the SVG and thumbnail output directories if they are missing."""
    os.makedirs(SVG_DIR, exist_ok=True)
    os.makedirs(THUMBNAIL_DIR, exist_ok=True)
    print(f"Directories '{SVG_DIR}' and '{THUMBNAIL_DIR}' are ready.")


def sanitize_filename(prompt):
    """Return *prompt* with filename-unsafe characters removed.

    The result is stripped of surrounding whitespace and capped at 100
    characters. This single definition replaces two conflicting ones that
    previously existed in this module (one truncated, the other stripped).
    """
    cleaned = _INVALID_FILENAME_CHARS.sub("", prompt).strip()
    return cleaned[:100]


class ImageToSvgPipeline:
    """
    A class to handle the entire pipeline from text prompt to SVG.
    Initializes models once to be reused.
    """

    def __init__(self, sketch_model_path: str):
        """Select the compute device and load both models.

        Args:
            sketch_model_path: Path to the sketch generator weights file.

        Raises:
            FileNotFoundError: If ``sketch_model_path`` does not exist.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self._initialize_rinna_model()
        self._initialize_sketch_model(sketch_model_path)

    def _initialize_rinna_model(self):
        """Load the Stable Diffusion pipeline and move it to the device."""
        print("Loading Rinna Stable Diffusion model...")
        model_id = "rinna/japanese-stable-diffusion"
        # Half precision only on CUDA; full precision on CPU.
        self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )
        self.rinna_pipe.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
        )
        self.rinna_pipe.tokenizer.model_max_length = 77
        self.rinna_pipe.to(self.device)
        print("Rinna model loaded.")

    def _initialize_sketch_model(self, model_path: str):
        """Load the sketch generator weights and put the model in eval mode.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        print(f"Loading Sketch Generator model from {model_path}...")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Sketch model weights not found at: {model_path}")

        self.sketch_model = Generator(input_nc=3, output_nc=1, n_residual_blocks=3)
        self.sketch_model.to(self.device)
        # NOTE(review): torch.load deserializes with pickle; only point this
        # at checkpoint files from a trusted source.
        self.sketch_model.load_state_dict(
            torch.load(model_path, map_location=self.device)
        )
        self.sketch_model.eval()
        self.sketch_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        print("Sketch model loaded.")

    def _generate_image(self, prompt: str, negative_prompt: str, steps: int = 30) -> Image.Image:
        """Run the diffusion pipeline and return the first 512x512 image."""
        print(f"Generating image for prompt: '{prompt}'")
        with torch.no_grad():
            image = self.rinna_pipe(
                prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=7.5,
                width=512,
                height=512,
            ).images[0]
        return image

    def _convert_to_sketch(self, image: Image.Image) -> Image.Image:
        """Pass *image* through the sketch generator and return a PIL image."""
        print("Converting image to sketch...")
        with torch.no_grad():
            input_tensor = (
                self.sketch_transform(image.convert("RGB"))
                .unsqueeze(0)
                .to(self.device)
            )
            output_tensor = self.sketch_model(input_tensor)

        # Drop the batch dimension and move to CPU before converting to PIL.
        output_tensor = output_tensor.squeeze(0).cpu()
        sketch_image = transforms.ToPILImage()(output_tensor)
        return sketch_image

    def _extract_svg(self, image: Image.Image) -> str:
        """Trace *image* with vtracer and return the SVG document as text.

        vtracer works on file paths, so the image is round-tripped through a
        temporary directory. The directory (and both files in it) is removed
        on exit even if saving, tracing or reading fails.
        """
        print("Extracting SVG from sketch...")
        with tempfile.TemporaryDirectory() as tmp_dir:
            png_path = os.path.join(tmp_dir, "sketch.png")
            svg_output_path = os.path.join(tmp_dir, "sketch.svg")
            image.save(png_path)
            vtracer.convert_image_to_svg_py(png_path, svg_output_path)
            with open(svg_output_path, 'r', encoding='utf-8') as f:
                svg_data = f.read()

        print("SVG extraction complete.")
        return svg_data

    def process(self, prompt: str, negative_prompt: str) -> str:
        """Run prompt -> image -> sketch -> SVG and return the SVG text."""
        generated_image = self._generate_image(prompt, negative_prompt)
        sketch_image = self._convert_to_sketch(generated_image)
        svg_content = self._extract_svg(sketch_image)
        return svg_content


app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})

# The routes below write to and list these directories, so make sure they
# exist regardless of how the app is launched.
setup_directories()

pipeline = ImageToSvgPipeline(sketch_model_path=SKETCH_MODEL_WEIGHTS)


@app.route('/generate', methods=['POST'])
def generate_svg():
    """Generate an SVG for the JSON ``prompt`` and store it with a thumbnail.

    Returns the SVG document with ``image/svg+xml`` content type on success,
    a 400 JSON error when the prompt is missing, empty or not a string, and a
    500 JSON error when generation or saving fails.
    """
    # silent=True yields None for a missing or malformed JSON body, so the
    # check below can answer with the documented 400.
    data = request.get_json(silent=True)
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Prompt is required"}), 400

    negative_prompt = "低品質、最悪の品質、下手な手、指が6本、指が4本、奇形、醜い、ぼやけている、ぼやけた、ウォーターマーク、署名、テキスト"

    try:
        svg_result = pipeline.process(prompt, negative_prompt)

        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        safe_prompt = sanitize_filename(prompt)[:50]
        filename = f"{timestamp}_{safe_prompt}.svg"

        svg_path = os.path.join(SVG_DIR, filename)
        with open(svg_path, 'w', encoding='utf-8') as f:
            f.write(svg_result)

        # Thumbnail name mirrors the SVG name; the gallery and delete routes
        # derive it with the same replace() call.
        thumbnail_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
        cairosvg.svg2png(
            bytestring=svg_result.encode('utf-8'),
            write_to=thumbnail_path,
            output_width=256,
            output_height=256,
        )

        return svg_result, 200, {'Content-Type': 'image/svg+xml'}
    except Exception as e:
        # Top-level request boundary: report the failure instead of crashing.
        print(f"An error occurred during generation: {e}")
        return jsonify({"error": str(e)}), 500


@app.route('/gallery', methods=['GET'])
def get_gallery():
    """Return one page of stored drawings, newest filename first.

    Query parameters ``page`` (default 1) and ``limit`` (default 8) must be
    integers; values below 1 are raised to 1. The response is
    ``{"drawings": [...], "hasMore": bool}``.
    """
    try:
        page = int(request.args.get('page', 1))
        limit = int(request.args.get('limit', 8))
    except ValueError:
        return jsonify({"error": "page and limit must be integers"}), 400

    # A page or limit below 1 would produce a negative or empty slice.
    page = max(page, 1)
    limit = max(limit, 1)

    try:
        svg_files = sorted(
            (f for f in os.listdir(SVG_DIR) if f.endswith('.svg')),
            reverse=True,
        )

        start_index = (page - 1) * limit
        end_index = start_index + limit

        drawings = []
        for filename in svg_files[start_index:end_index]:
            prompt_match = _GALLERY_NAME_RE.match(filename)
            prompt = (
                prompt_match.group(1).replace('_', ' ')
                if prompt_match
                else "Prompt not found"
            )
            drawings.append({
                "filename": filename,
                "thumbnail": f"/thumbnails/{filename.replace('.svg', '.png')}",
                "prompt": prompt,
            })

        has_more = end_index < len(svg_files)
        return jsonify({"drawings": drawings, "hasMore": has_more})
    except Exception as e:
        print(f"Error fetching gallery: {e}")
        return jsonify({"error": "Failed to fetch gallery"}), 500


# The URL variable is required: the view takes `filename`, so a rule without
# it raises TypeError on every request.
@app.route('/svgs/<filename>')
def get_svg(filename):
    """Serve a stored SVG from ``SVG_DIR`` by filename."""
    return send_from_directory(SVG_DIR, filename)
# The URL variable is required: the view takes `filename`, so a rule without
# it raises TypeError on every request.
@app.route('/thumbnails/<filename>')
def get_thumbnail(filename):
    """Serve a stored PNG thumbnail from ``THUMBNAIL_DIR`` by filename."""
    return send_from_directory(THUMBNAIL_DIR, filename)


@app.route('/drawings/<filename>', methods=['DELETE'])
def delete_drawing_file(filename):
    """Delete a stored SVG and its thumbnail.

    Only bare ``*.svg`` filenames are accepted; anything else gets a 400.
    Files that are already absent are skipped, so the call still reports
    success for them. Returns a 500 JSON error if removal fails.
    """
    # The name comes straight from the URL and is joined into paths handed to
    # os.remove, so refuse anything that is not a plain SVG basename.
    if os.path.basename(filename) != filename or not filename.endswith('.svg'):
        return jsonify({"error": "Invalid filename"}), 400

    try:
        svg_path = os.path.join(SVG_DIR, filename)
        thumb_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))

        for path in (svg_path, thumb_path):
            if os.path.exists(path):
                os.remove(path)

        return jsonify({"message": f"Successfully deleted {filename}"})
    except Exception as e:
        print(f"Error deleting file: {e}")
        return jsonify({"error": "Failed to delete file"}), 500


if __name__ == '__main__':
    print("Starting Flask server...")
    # Binds to all interfaces on port 5000.
    app.run(host='0.0.0.0', port=5000)