Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import vtracer | |
| import tempfile | |
| import cairosvg | |
| import re | |
| from PIL import Image | |
| from datetime import datetime | |
| import gc | |
| import json | |
| import time | |
| import queue | |
| import threading | |
| from flask import Flask, request, jsonify, send_from_directory, Response, stream_with_context | |
| from flask_cors import CORS | |
| from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler | |
| from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | |
| import torchvision.transforms as transforms | |
| from model import Generator | |
| from utils import process_svg | |
# Configure the PyTorch CUDA caching allocator through its environment variable.
# NOTE(review): presumably intended to reduce fragmentation-related OOMs;
# confirm against the PyTorch CUDA memory-management notes.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")
def setup_directories():
    """Create the stroke and thumbnail output directories if they are missing."""
    for directory in (STROKES_DIR, THUMBNAIL_DIR):
        # exist_ok=True makes repeated calls harmless.
        os.makedirs(directory, exist_ok=True)
    print(f"Directories '{STROKES_DIR}' and '{THUMBNAIL_DIR}' are ready.")
def sanitize_filename(prompt):
    """Return *prompt* reduced to text that is safe to embed in a filename.

    Removes the characters ``\\ / * ? : " < > |`` as well as ASCII control
    characters (NUL, newlines, tabs, DEL, ...). Control characters are
    stripped because a NUL byte makes ``open()`` raise and the others are
    not valid in Windows filenames.

    Args:
        prompt: Arbitrary user-supplied text.

    Returns:
        The cleaned text, truncated to at most 100 characters. May be empty.
    """
    cleaned = re.sub(r'[\\/*?:"<>|\x00-\x1f\x7f]', "", prompt)
    return cleaned[:100]
# Output folders are resolved against the working directory at import time.
_WORKING_DIR = os.getcwd()
STROKES_DIR = os.path.join(_WORKING_DIR, 'strokes')
THUMBNAIL_DIR = os.path.join(_WORKING_DIR, 'thumbnails')
# Relative path to the sketch generator weights.
SKETCH_MODEL_WEIGHTS = os.path.join('checkpoints', 'netG_A_latest.pth')
class ImageToSvgPipeline:
    """Turn a text prompt or an existing image file into SVG markup.

    Stages: optional Stable Diffusion image generation, a sketch generator
    network, then vectorization with vtracer.
    """

    def __init__(self, sketch_model_path: str):
        """Pick the compute device and load both models.

        Args:
            sketch_model_path: Path to the sketch generator weights file.

        Raises:
            FileNotFoundError: If the sketch weights file does not exist.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self._initialize_rinna_model()
        self._initialize_sketch_model(sketch_model_path)

    def _initialize_rinna_model(self):
        """Load the Stable Diffusion pipeline onto ``self.device``."""
        print("Loading Rinna Stable Diffusion model...")
        model_id = "rinna/japanese-stable-diffusion"
        self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            # Half precision only on GPU; CPU uses float32.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )
        self.rinna_pipe.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012,
            beta_schedule="scaled_linear", num_train_timesteps=1000
        )
        self.rinna_pipe.tokenizer.model_max_length = 77
        self.rinna_pipe.to(self.device)
        self.rinna_pipe.set_progress_bar_config(disable=True)
        print("Rinna model loaded.")

    def unload_rinna_model(self):
        """Drop the Stable Diffusion pipeline and release cached GPU memory."""
        if hasattr(self, 'rinna_pipe'):
            print("Unloading Rinna Stable Diffusion model...")
            del self.rinna_pipe
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()
                print("GPU memory cache cleared.")
            print("Rinna model unloaded successfully.")
        else:
            print("Rinna model is not currently loaded.")

    def _initialize_sketch_model(self, model_path: str):
        """Load the sketch generator weights and prepare the input transform.

        Args:
            model_path: Path to the generator weights file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        print(f"Loading Sketch Generator model from {model_path}...")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Sketch model weights not found at: {model_path}")
        self.sketch_model = Generator(input_nc=3, output_nc=1, n_residual_blocks=3)
        self.sketch_model.to(self.device)
        # SECURITY NOTE(review): torch.load unpickles the checkpoint; only load
        # weight files from a trusted source (consider weights_only=True).
        self.sketch_model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.sketch_model.eval()
        self.sketch_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        print("Sketch model loaded.")

    def _generate_image(self, prompt: str, negative_prompt: str, steps: int = 30, callback=None) -> Image.Image:
        """Generate one 720x720 image from ``prompt``.

        Args:
            prompt: Text prompt passed to the diffusion pipeline.
            negative_prompt: Negative prompt passed to the pipeline.
            steps: Number of inference steps.
            callback: Optional function passed as ``callback_on_step_end``.

        Returns:
            The first image of the pipeline output.

        Raises:
            RuntimeError: If the diffusion model is currently unloaded.
        """
        if not hasattr(self, 'rinna_pipe'):
            # unload_rinna_model() may have removed the pipeline; fail with a
            # clear message instead of an AttributeError.
            raise RuntimeError("Rinna model is not currently loaded.")
        print(f"Generating image for prompt: '{prompt}'")
        with torch.no_grad():
            output: StableDiffusionPipelineOutput = self.rinna_pipe(
                prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=7.5,
                width=720,
                height=720,
                callback_on_step_end=callback
            )
        return output.images[0]

    def _convert_to_sketch(self, image: Image.Image) -> Image.Image:
        """Run ``image`` through the sketch generator and return a PIL image."""
        print("Converting image to sketch...")
        with torch.no_grad():
            input_tensor = self.sketch_transform(image.convert("RGB")).unsqueeze(0).to(self.device)
            output_tensor = self.sketch_model(input_tensor)
        # Drop the batch dimension and move to CPU before the PIL conversion.
        output_tensor = output_tensor.squeeze(0).cpu()
        sketch_image = transforms.ToPILImage()(output_tensor)
        return sketch_image

    def _extract_svg(self, image: Image.Image) -> str:
        """Vectorize ``image`` with vtracer and return the SVG text.

        vtracer works on file paths, so the image is written into a private
        temporary directory that is removed on exit, even on failure.
        """
        print("Extracting SVG from sketch...")
        with tempfile.TemporaryDirectory() as work_dir:
            png_path = os.path.join(work_dir, "sketch.png")
            svg_path = os.path.join(work_dir, "sketch.svg")
            image.save(png_path)
            vtracer.convert_image_to_svg_py(png_path, svg_path)
            with open(svg_path, 'r', encoding='utf-8') as f:
                svg_data = f.read()
        print("SVG extraction complete.")
        return svg_data

    def process(self, prompt: str, img_path: str, negative_prompt: str, callback=None):
        """Processes the image generation and conversion, with progress callbacks.

        Args:
            prompt: Text prompt; used only when ``img_path`` is None.
            img_path: Path of an existing image to convert, or None to
                generate an image from ``prompt``.
            negative_prompt: Negative prompt for generation.
            callback: Optional ``callback(progress, step_name)`` invoked with
                an integer percentage and a status message.

        Returns:
            Tuple ``(svg_content, generated_img)``: the SVG text and the
            source image (generated or loaded) as a PIL image.
        """
        def _callback(progress, step_name):
            if callback:
                callback(progress, step_name)

        generated_img = None
        if img_path is None:
            total_diffusion_steps = 30

            def diffusion_callback(pipe, step_index, timestep, callback_kwargs):
                # Map diffusion steps onto the 5%..80% progress range.
                progress = int(5 + ((step_index + 1) / total_diffusion_steps) * 75)
                _callback(progress, "Generating image...")
                return callback_kwargs

            _callback(5, "Starting image generation...")
            generated_img = self._generate_image(
                prompt,
                negative_prompt,
                steps=total_diffusion_steps,
                callback=diffusion_callback
            )
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()
            _callback(80, "Base image generated.")
            img_to_process = generated_img
        else:
            # Copy the pixels and close the file right away so that batch
            # imports do not accumulate open file handles.
            with Image.open(img_path) as source_file:
                generated_img = source_file.copy()
            img_to_process = generated_img
            _callback(80, "Image loaded.")

        _callback(85, "Converting to sketch...")
        sketch_image = self._convert_to_sketch(img_to_process)
        _callback(90, "Vectorizing sketch...")
        svg_content = self._extract_svg(sketch_image)
        _callback(95, "SVG extracted.")
        return svg_content, generated_img
# Application wiring: the Flask app, CORS open to every origin on every
# route, and the model pipeline, which is loaded once at import time.
app = Flask(__name__)
_CORS_RESOURCES = {r"/*": {"origins": "*"}}
CORS(app, resources=_CORS_RESOURCES)
pipeline = ImageToSvgPipeline(sketch_model_path=SKETCH_MODEL_WEIGHTS)
def add_ngrok_header(response):
    """Set the ``ngrok-skip-browser-warning`` header on *response* and return it.

    NOTE(review): no registration (e.g. an after-request hook decorator) is
    visible in this view; confirm this function is attached to ``app``.
    """
    header_name, header_value = 'ngrok-skip-browser-warning', 'true'
    response.headers[header_name] = header_value
    return response
def generate_stroke():
    """Generate strokes for the ``prompt`` query parameter, streaming progress.

    The pipeline runs in a background thread that pushes JSON progress
    messages onto a queue; the response drains that queue as Server-Sent
    Events (``data: <json>`` frames). The final frame carries ``"result"``
    on success or ``"error"`` on failure.

    Returns:
        A ``text/event-stream`` response, or a 400 JSON error when the
        prompt is missing.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered.
    """
    prompt = request.args.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400

    negative_prompt = (
        "低品質、最悪の品質、奇形、醜い、ぼやけている、ぼやけた、"
        "ウォーターマーク、署名、テキスト、フレームから外れた、"
        "手足が切れている、クロップされた、被写体が切り取られている、"
        "構成が悪い、焦点が合っていない"
    )
    q = queue.Queue()

    def worker():
        """Runs the long-running task in a separate thread and puts progress into the queue."""
        start_time = time.time()

        def progress_callback(progress, step):
            print(f"Progress: {progress}% - {step}")
            data = json.dumps({"progress": progress, "step": step})
            q.put(data)

        try:
            progress_callback(5, "Initializing...")
            svg_result, generated_image = pipeline.process(prompt, None, negative_prompt, callback=progress_callback)
            progress_callback(98, "Finalizing and saving...")

            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            safe_prompt = sanitize_filename(prompt)[:60]
            filename_base = f"{timestamp}_{safe_prompt}"

            stroke_path = os.path.join(STROKES_DIR, f"{filename_base}.json")
            stroke = process_svg(svg_result, "file")
            with open(stroke_path, 'w', encoding='utf-8') as f:
                json.dump(stroke, f, ensure_ascii=False, indent=2)

            # Explicit None check instead of relying on object truthiness.
            if generated_image is not None:
                thumbnail_path = os.path.join(THUMBNAIL_DIR, f"{filename_base}.png")
                cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)

            final_data = json.dumps({"progress": 100, "result": stroke, "step": "Complete!"})
            q.put(final_data)
            end_time = time.time()
            print(f"Total generation time: {end_time - start_time:.2f} seconds")
        except Exception as e:
            # Top-level boundary of the worker thread: report the failure to
            # the client instead of letting the thread die silently.
            print(f"Error during generation stream: {e}")
            error_data = json.dumps({"error": str(e), "progress": 100})
            q.put(error_data)
        finally:
            # Sentinel: always tell the stream generator to stop.
            q.put(None)

    threading.Thread(target=worker).start()

    def generate():
        """This generator reads from the queue and yields data to the client."""
        while True:
            item = q.get()
            if item is None:
                break
            yield f"data: {item}\n\n"

    response = Response(stream_with_context(generate()), mimetype='text/event-stream')
    # Event streams must not be cached; the second header asks nginx-style
    # reverse proxies not to buffer the stream.
    response.headers['Cache-Control'] = 'no-cache'
    response.headers['X-Accel-Buffering'] = 'no'
    return response
def get_gallery():
    """Return one page of saved drawings, newest first.

    Query parameters:
        page: 1-based page number (default 1).
        limit: Number of entries per page (default 8).

    Returns:
        JSON ``{"drawings": [...], "hasMore": bool}``; 400 when ``page`` or
        ``limit`` is not a positive integer; 500 on unexpected failures.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered.
    """
    try:
        page = int(request.args.get('page', 1))
        limit = int(request.args.get('limit', 8))
    except (TypeError, ValueError):
        return jsonify({"error": "page and limit must be integers"}), 400
    if page < 1 or limit < 1:
        # Non-positive values would yield negative slice indices.
        return jsonify({"error": "page and limit must be positive"}), 400

    try:
        try:
            entries = os.listdir(STROKES_DIR)
        except FileNotFoundError:
            # Nothing has been saved yet: treat as an empty gallery.
            entries = []
        # Names start with a timestamp, so reverse sort lists newest first.
        strokes_files = sorted((f for f in entries if f.endswith('.json')), reverse=True)

        start_index = (page - 1) * limit
        end_index = start_index + limit
        paginated_files = strokes_files[start_index:end_index]

        drawings = []
        for filename in paginated_files:
            prompt_match = re.match(r"\d+_(.+)\.json", filename)
            prompt = prompt_match.group(1).replace('_', ' ') if prompt_match else "Prompt not found"
            # Swap only the extension; the prompt part may itself contain ".json".
            thumbnail_name = os.path.splitext(filename)[0] + '.png'
            drawings.append({
                "filename": filename,
                "thumbnail": f"/thumbnails/{thumbnail_name}",
                "prompt": prompt
            })

        has_more = end_index < len(strokes_files)
        return jsonify({"drawings": drawings, "hasMore": has_more})
    except Exception as e:
        print(f"Error fetching gallery: {e}")
        return jsonify({"error": "Failed to fetch gallery"}), 500
def add_svg():
    """Import every ``.svg`` file from a server-side folder into the gallery.

    Expects a JSON body ``{"folderPath": "<directory>"}``. For each SVG file a
    stroke JSON is written to ``STROKES_DIR`` and a 256x256 PNG thumbnail to
    ``THUMBNAIL_DIR``.

    Returns:
        JSON status with the number of processed files, or a 400 JSON error
        when ``folderPath`` is missing or is not a directory.

    SECURITY NOTE(review): ``folderPath`` comes from the client and lets the
    caller read any directory the server can access; restrict or
    authenticate this endpoint if it is reachable by untrusted users.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered.
    """
    data = request.get_json(silent=True)
    raw_path = data.get('folderPath') if isinstance(data, dict) else None
    folder_path = raw_path.strip() if isinstance(raw_path, str) else ''
    if not folder_path:
        return jsonify({"error": "folderPath is required"}), 400
    if not os.path.isdir(folder_path):
        return jsonify({"error": "folderPath is not a directory"}), 400

    count = 0
    for file in sorted(os.listdir(folder_path)):
        base_name, extension = os.path.splitext(file)
        file_path = os.path.join(folder_path, file)
        # Skip sub-directories and anything that is not an SVG file.
        if extension.lower() != '.svg' or not os.path.isfile(file_path):
            continue

        stroke_path = os.path.join(STROKES_DIR, base_name + '.json')
        stroke = process_svg(file_path, "path")
        with open(stroke_path, 'w', encoding='utf-8') as f:
            json.dump(stroke, f, ensure_ascii=False, indent=2)

        thumbnail_path = os.path.join(THUMBNAIL_DIR, base_name + '.png')
        cairosvg.svg2png(url=file_path, write_to=thumbnail_path, output_width=256, output_height=256)
        count += 1

    return jsonify({"status": "success", "message": f"Processed {count} SVG files."})
def add_img():
    """Convert every image in a server-side folder into gallery strokes.

    Expects a JSON body ``{"folderPath": "<directory>"}``. The diffusion model
    is unloaded for the duration of the batch and always reloaded afterwards,
    even when a file fails to convert.

    Returns:
        JSON status with the number of processed files, or a 400 JSON error
        when ``folderPath`` is missing or is not a directory.

    SECURITY NOTE(review): ``folderPath`` comes from the client and lets the
    caller read any directory the server can access; restrict or
    authenticate this endpoint if it is reachable by untrusted users.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered.
    """
    image_extensions = ('.jpg', '.jpeg', '.png')

    data = request.get_json(silent=True)
    raw_path = data.get('folderPath') if isinstance(data, dict) else None
    folder_path = raw_path.strip() if isinstance(raw_path, str) else ''
    if not folder_path:
        return jsonify({"error": "folderPath is required"}), 400
    if not os.path.isdir(folder_path):
        return jsonify({"error": "folderPath is not a directory"}), 400

    count = 0
    pipeline.unload_rinna_model()
    try:
        for file in sorted(os.listdir(folder_path)):
            base_name, extension = os.path.splitext(file)
            file_path = os.path.join(folder_path, file)
            # Skip sub-directories and non-image files.
            if extension.lower() not in image_extensions or not os.path.isfile(file_path):
                continue

            svg_result, _ = pipeline.process(None, file_path, None)

            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            filename_base = f"{timestamp}_{base_name}"

            stroke_path = os.path.join(STROKES_DIR, filename_base + '.json')
            stroke = process_svg(svg_result, "file")
            with open(stroke_path, 'w', encoding='utf-8') as f:
                json.dump(stroke, f, ensure_ascii=False, indent=2)

            thumbnail_path = os.path.join(THUMBNAIL_DIR, filename_base + '.png')
            cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)
            count += 1
    finally:
        # Restore the diffusion model even if a file failed; otherwise later
        # text-to-image requests would find no model loaded.
        pipeline._initialize_rinna_model()

    return jsonify({"status": "success", "message": f"Processed {count} image files."})
def get_strokes(filename):
    """Serve the stroke file *filename* from ``STROKES_DIR``.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered.
    """
    stroke_response = send_from_directory(STROKES_DIR, filename)
    return stroke_response
def get_thumbnail(filename):
    """Serve the thumbnail file *filename* from ``THUMBNAIL_DIR``.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered.
    """
    thumbnail_response = send_from_directory(THUMBNAIL_DIR, filename)
    return thumbnail_response
def delete_drawing_file(filename):
    """Delete a drawing's stroke file and its thumbnail.

    Args:
        filename: Name of the stroke file inside ``STROKES_DIR``.

    Returns:
        JSON success message (also when the files were already absent),
        400 when ``filename`` contains path components, 500 when deletion
        fails.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered.
    """
    # The name comes from the client: refuse anything that could escape the
    # storage directories (separators or dot-only names).
    has_separator = any(sep in filename for sep in ('/', '\\'))
    if not filename or has_separator or filename in ('.', '..'):
        return jsonify({"error": "Invalid filename"}), 400

    json_path = os.path.join(STROKES_DIR, filename)
    # Swap only the extension rather than every ".json" occurrence.
    thumb_path = os.path.join(THUMBNAIL_DIR, os.path.splitext(filename)[0] + '.png')
    try:
        for path in (json_path, thumb_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already gone: deletion is treated as idempotent.
                pass
        return jsonify({"message": f"Successfully deleted {filename}"})
    except OSError as e:
        print(f"Error deleting file: {e}")
        return jsonify({"error": "Failed to delete file"}), 500
# `app` is a Flask application: it has no `mount()` method, and neither
# `StaticFiles` nor `uvicorn` is imported in this file, so the former
# FastAPI-style start-up failed on import. Stroke and thumbnail files are
# handled by get_strokes() / get_thumbnail() via send_from_directory.

# Make sure the output folders exist before any request reads or writes them.
setup_directories()

if __name__ == '__main__':
    print("Starting Flask server...")
    # threaded=True lets other requests proceed while an event stream is open.
    app.run(host='0.0.0.0', port=7860, threaded=True)