Spaces:
Runtime error
Runtime error
import os
import re
import tempfile
import threading
from datetime import datetime

import cairosvg
import torch
import torchvision.transforms as transforms
import vtracer
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from PIL import Image

from model import Generator
def setup_directories():
    """Create the SVG output and thumbnail directories if they do not exist.

    Reads the module-level ``SVG_DIR`` and ``THUMBNAIL_DIR`` constants at call
    time. Directories that already exist are left as they are.
    """
    for directory in (SVG_DIR, THUMBNAIL_DIR):
        os.makedirs(directory, exist_ok=True)
    print(f"Directories '{SVG_DIR}' and '{THUMBNAIL_DIR}' are ready.")
def sanitize_filename(prompt):
    """Drop characters that are invalid in file names and cap the length at 100.

    Removes backslash, slash, ``*``, ``?``, ``:``, ``"``, ``<``, ``>`` and
    ``|``; whitespace is kept as-is.

    NOTE(review): ``sanitize_filename`` is defined a second time later in this
    module. That later definition rebinds the name while the module is being
    imported, so this version is not the one the request handlers end up using.
    """
    forbidden = re.compile(r'[\\/*?:"<>|]')
    cleaned = forbidden.sub("", prompt)
    return cleaned[:100]
# Output directories, resolved once at import time against the current
# working directory (not the location of this file).
SVG_DIR, THUMBNAIL_DIR = (
    os.path.join(os.getcwd(), subdir) for subdir in ('generated_svgs', 'thumbnails')
)
# Relative path, so it is likewise resolved against the working directory.
SKETCH_MODEL_WEIGHTS = 'checkpoints/netG_A_latest.pth'
class ImageToSvgPipeline:
    """
    Text prompt -> raster image -> sketch -> SVG pipeline.

    The rinna Japanese Stable Diffusion pipeline and the sketch ``Generator``
    are loaded once in ``__init__`` and reused by every call to ``process``.

    Model inference in ``process`` is serialized with a lock. The pipeline
    object (including its scheduler) is shared by all callers, and Flask's
    development server serves requests on multiple threads by default, so
    concurrent requests would otherwise run the shared models at the same time.
    """

    def __init__(self, sketch_model_path: str):
        """Pick a device and load both models.

        Args:
            sketch_model_path: Path to the sketch ``Generator`` state dict.

        Raises:
            FileNotFoundError: If ``sketch_model_path`` does not exist.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Serializes use of the shared models in process(); see class docstring.
        self._lock = threading.Lock()
        self._initialize_rinna_model()
        self._initialize_sketch_model(sketch_model_path)

    def _initialize_rinna_model(self):
        """Load rinna/japanese-stable-diffusion and install an LMS scheduler."""
        print("Loading Rinna Stable Diffusion model...")
        model_id = "rinna/japanese-stable-diffusion"
        # float16 weights only on CUDA; CPU runs in float32.
        self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )
        # Scheduler hyper-parameters are hard-coded here rather than read from
        # the checkpoint's scheduler config.
        self.rinna_pipe.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012,
            beta_schedule="scaled_linear", num_train_timesteps=1000
        )
        # NOTE(review): presumably matches the text encoder's context length — confirm.
        self.rinna_pipe.tokenizer.model_max_length = 77
        self.rinna_pipe.to(self.device)
        print("Rinna model loaded.")

    def _initialize_sketch_model(self, model_path: str):
        """Build the sketch ``Generator`` and load its weights from ``model_path``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        print(f"Loading Sketch Generator model from {model_path}...")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Sketch model weights not found at: {model_path}")
        # 3 input channels (the image is converted to RGB before inference),
        # 1 output channel.
        self.sketch_model = Generator(input_nc=3, output_nc=1, n_residual_blocks=3)
        self.sketch_model.to(self.device)
        self.sketch_model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.sketch_model.eval()
        # Only PIL -> float tensor conversion; no normalization is applied.
        self.sketch_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        print("Sketch model loaded.")

    def _generate_image(self, prompt: str, negative_prompt: str, steps: int = 30) -> Image.Image:
        """Generate a single 512x512 image for ``prompt`` (guidance_scale 7.5)."""
        print(f"Generating image for prompt: '{prompt}'")
        with torch.no_grad():
            image = self.rinna_pipe(
                prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=7.5,
                width=512,
                height=512,
            ).images[0]
        return image

    def _convert_to_sketch(self, image: Image.Image) -> Image.Image:
        """Run the sketch generator on ``image`` and return its output as a PIL image."""
        print("Converting image to sketch...")
        with torch.no_grad():
            input_tensor = self.sketch_transform(image.convert("RGB")).unsqueeze(0).to(self.device)
            output_tensor = self.sketch_model(input_tensor)
            # ToPILImage multiplies float tensors by 255 and casts to uint8
            # without clamping, so out-of-range values would wrap around.
            output_tensor = output_tensor.squeeze(0).clamp(0.0, 1.0).cpu()
        return transforms.ToPILImage()(output_tensor)

    def _extract_svg(self, image: Image.Image) -> str:
        """Trace ``image`` with vtracer and return the resulting SVG markup.

        The intermediate PNG and the SVG are written inside a private temporary
        directory. That directory, with both files, is removed even if saving,
        tracing or reading fails.
        """
        print("Extracting SVG from sketch...")
        with tempfile.TemporaryDirectory() as tmp_dir:
            png_path = os.path.join(tmp_dir, "sketch.png")
            svg_path = os.path.join(tmp_dir, "sketch.svg")
            image.save(png_path)
            vtracer.convert_image_to_svg_py(png_path, svg_path)
            with open(svg_path, 'r', encoding='utf-8') as f:
                svg_data = f.read()
        print("SVG extraction complete.")
        return svg_data

    def process(self, prompt: str, negative_prompt: str, steps: int = 30) -> str:
        """Turn a text prompt into an SVG line drawing.

        Args:
            prompt: Text prompt for the diffusion model.
            negative_prompt: Negative prompt passed to the diffusion model.
            steps: Number of diffusion inference steps (default 30).

        Returns:
            The SVG document as a string.
        """
        # Only the model stages need the lock; tracing uses its own temp dir.
        with self._lock:
            generated_image = self._generate_image(prompt, negative_prompt, steps)
            sketch_image = self._convert_to_sketch(generated_image)
        return self._extract_svg(sketch_image)
app = Flask(__name__)
# NOTE(review): every origin is allowed, including for any state-changing
# endpoints — consider restricting origins if this is exposed publicly.
CORS(app, resources={r"/*": {"origins": "*"}})
# The request handlers write into and list these directories, and nothing else
# in this module calls setup_directories(); create them before serving.
setup_directories()
# Loads both models at import time (slow).
pipeline = ImageToSvgPipeline(sketch_model_path=SKETCH_MODEL_WEIGHTS)
def sanitize_filename(text, max_length=None):
    """Return ``text`` with characters that are unsafe in file names removed.

    Removes backslash, slash, ``*``, ``?``, ``:``, ``"``, ``<``, ``>``, ``|``
    and ASCII control characters (e.g. newlines and tabs from multi-line
    prompts), then trims surrounding whitespace.

    This definition rebinds the earlier ``sanitize_filename`` in this module,
    so it is the version the request handlers use.

    Args:
        text: Raw user input, e.g. a generation prompt.
        max_length: Optional maximum length of the result. ``None`` (the
            default) keeps the whole string.

    Returns:
        The sanitized string (possibly empty).
    """
    cleaned = re.sub(r'[\\/*?:"<>|\x00-\x1f\x7f]', "", text).strip()
    if max_length is not None:
        # Re-strip so truncation cannot leave trailing whitespace.
        cleaned = cleaned[:max_length].rstrip()
    return cleaned
def generate_svg():
    """Generate an SVG for the request's ``prompt``, store it, and return it.

    Expects a JSON object body with a non-empty string ``prompt``. On success:

    - The SVG is saved to ``SVG_DIR`` as ``<YYYYmmddHHMMSS>_<sanitized prompt>.svg``,
      with the sanitized prompt cut to 50 characters.
    - A 256x256 PNG thumbnail is written to ``THUMBNAIL_DIR`` under the same
      name with ``.svg`` replaced by ``.png``.
    - The SVG markup is returned with ``Content-Type: image/svg+xml``.

    Returns 400 if the body is not a JSON object or ``prompt`` is missing,
    empty or not a string, and 500 if generation or saving fails.

    NOTE(review): no ``@app.route`` decorator is visible for this view in the
    reviewed source — confirm it is registered.
    """
    # silent=True: an absent/invalid JSON body gives None instead of raising,
    # so it is reported as a JSON 400 below.
    data = request.get_json(silent=True)
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not prompt or not isinstance(prompt, str):
        return jsonify({"error": "Prompt is required"}), 400
    # NOTE(review): this literal looks mis-encoded (UTF-8 Japanese rendered
    # through a Thai code page). It was presumably a Japanese list along the
    # lines of "low quality, worst quality, bad hands, ..., deformed, ugly,
    # blurry, watermark, signature, text" — confirm against the original file.
    negative_prompt = "ไฝๅ่ณชใๆๆชใฎๅ่ณชใไธๆใชๆใๆใ6ๆฌใๆใ4ๆฌใๅฅๅฝขใ้ใใใผใใใฆใใใผใใใใใฆใฉใผใฟใผใใผใฏใ็ฝฒๅใใใญในใ"
    try:
        svg_result = pipeline.process(prompt, negative_prompt)
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        safe_prompt = sanitize_filename(prompt)[:50]
        filename = f"{timestamp}_{safe_prompt}.svg"
        svg_path = os.path.join(SVG_DIR, filename)
        with open(svg_path, 'w', encoding='utf-8') as f:
            f.write(svg_result)
        thumbnail_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
        cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)
        return svg_result, 200, {'Content-Type': 'image/svg+xml'}
    except Exception as e:
        print(f"An error occurred during generation: {e}")
        return jsonify({"error": str(e)}), 500
def get_gallery():
    """Return one page of saved drawings, newest file name first, as JSON.

    Query parameters:
        page: 1-based page number (default 1).
        limit: Number of drawings per page (default 8).

    Response body: ``{"drawings": [{"filename", "thumbnail", "prompt"}, ...],
    "hasMore": bool}``. The prompt is recovered from file names shaped like
    ``<digits>_<prompt>.svg``, with underscores shown as spaces.

    Returns 400 when ``page``/``limit`` are not positive integers and 500 on
    other failures. A missing ``SVG_DIR`` is treated as an empty gallery.

    NOTE(review): no ``@app.route`` decorator is visible for this view in the
    reviewed source — confirm it is registered.
    """
    try:
        page = int(request.args.get('page', 1))
        limit = int(request.args.get('limit', 8))
    except ValueError:
        return jsonify({"error": "page and limit must be integers"}), 400
    # Non-positive values would produce negative slice indices and return the
    # wrong files rather than an empty page.
    if page < 1 or limit < 1:
        return jsonify({"error": "page and limit must be positive integers"}), 400
    try:
        listing = os.listdir(SVG_DIR) if os.path.isdir(SVG_DIR) else []
        # Names start with a fixed-width timestamp, so reverse name order is
        # newest first.
        svg_files = sorted((f for f in listing if f.endswith('.svg')), reverse=True)
        start_index = (page - 1) * limit
        end_index = start_index + limit
        drawings = []
        for filename in svg_files[start_index:end_index]:
            prompt_match = re.match(r"\d+_(.+)\.svg", filename)
            prompt = prompt_match.group(1).replace('_', ' ') if prompt_match else "Prompt not found"
            drawings.append({
                "filename": filename,
                "thumbnail": f"/thumbnails/{filename.replace('.svg', '.png')}",
                "prompt": prompt,
            })
        has_more = end_index < len(svg_files)
        return jsonify({"drawings": drawings, "hasMore": has_more})
    except Exception as e:
        print(f"Error fetching gallery: {e}")
        return jsonify({"error": "Failed to fetch gallery"}), 500
def get_svg(filename):
    """Serve the stored SVG named ``filename`` from ``SVG_DIR``.

    Path safety and 404 handling are delegated to Flask's
    ``send_from_directory``.
    """
    svg_root = SVG_DIR
    return send_from_directory(svg_root, filename)
def get_thumbnail(filename):
    """Serve the PNG thumbnail named ``filename`` from ``THUMBNAIL_DIR``.

    Path safety and 404 handling are delegated to Flask's
    ``send_from_directory``.
    """
    thumb_root = THUMBNAIL_DIR
    return send_from_directory(thumb_root, filename)
def delete_drawing_file(filename):
    """Delete a stored SVG and its PNG thumbnail.

    ``filename`` is client-supplied, so it must be a bare ``.svg`` file name.
    Values containing a directory component (e.g. ``../x.svg`` or an absolute
    path) or another extension are rejected with 400. Joining them onto
    ``SVG_DIR`` could otherwise delete files outside the drawing directories.

    Files that are already gone are ignored, so the call reports success even
    if nothing was deleted. The thumbnail name is derived the same way as when
    it was created: ``.svg`` replaced by ``.png``.

    NOTE(review): no ``@app.route`` decorator is visible for this view in the
    reviewed source — confirm it is registered.
    """
    if not filename or not filename.endswith('.svg') or os.path.basename(filename) != filename:
        return jsonify({"error": "Invalid filename"}), 400
    try:
        svg_path = os.path.join(SVG_DIR, filename)
        thumb_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
        for path in (svg_path, thumb_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # already removed; deletion is idempotent
        return jsonify({"message": f"Successfully deleted {filename}"})
    except Exception as e:
        print(f"Error deleting file: {e}")
        return jsonify({"error": "Failed to delete file"}), 500
if __name__ == '__main__':
    print("Starting Flask server...")
    # Port can be overridden through the PORT environment variable
    # (e.g. on hosts that assign one); defaults to 5000.
    port = int(os.environ.get('PORT', 5000))
    app.run(host='0.0.0.0', port=port)