jree423 commited on
Commit
6d9283b
·
verified ·
1 Parent(s): b6dea93

Add: diffsketcher handler.py with original implementation

Browse files
Files changed (1) hide show
  1. handler.py +59 -28
handler.py CHANGED
@@ -1,16 +1,10 @@
1
  import os
2
  import io
3
  import sys
4
- import base64
5
- import json
6
  import torch
7
  import numpy as np
8
  from PIL import Image
9
- import requests
10
- import tempfile
11
- import shutil
12
- import subprocess
13
- from pathlib import Path
14
 
15
  # Add debug logging
16
  def debug_log(message):
@@ -25,10 +19,24 @@ try:
25
  debug_log("Successfully imported cairosvg")
26
  except ImportError:
27
  debug_log("cairosvg not found. Installing...")
 
28
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
29
  import cairosvg
30
  debug_log("Installed and imported cairosvg")
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  class EndpointHandler:
33
  def __init__(self, model_dir):
34
  """Initialize the handler with model directory"""
@@ -38,26 +46,29 @@ class EndpointHandler:
38
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  debug_log(f"Using device: {self.device}")
40
 
41
- # Check if model weights exist
42
- weights_path = os.path.join(model_dir, "checkpoint.pth")
43
- if os.path.exists(weights_path):
44
- debug_log(f"Found model weights at {weights_path}")
45
- debug_log(f"Weights file size: {os.path.getsize(weights_path)} bytes")
46
- else:
47
- debug_log(f"Model weights not found at {weights_path}")
48
 
49
- # Try to find weights in other locations
50
- for root, dirs, files in os.walk(model_dir):
51
- for file in files:
52
- if file.endswith(".pth"):
53
- debug_log(f"Found weights file: {os.path.join(root, file)}")
54
-
55
- # For now, we'll just use a placeholder implementation
56
- self.use_model = False
57
- debug_log("Using placeholder implementation")
 
 
 
 
 
 
58
  except Exception as e:
59
  debug_log(f"Error in handler initialization: {e}")
60
- import traceback
61
  debug_log(traceback.format_exc())
62
  self.use_model = False
63
 
@@ -65,6 +76,28 @@ class EndpointHandler:
65
  """Generate an SVG from a text prompt"""
66
  debug_log(f"Generating SVG for prompt: {prompt}")
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # Create a more interesting placeholder that looks like a sketch
69
  svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
70
  <rect width="100%" height="100%" fill="#ffffff"/>
@@ -78,7 +111,7 @@ class EndpointHandler:
78
  <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" fill="#333333">{prompt}</text>
79
  </svg>"""
80
 
81
- debug_log("Generated SVG content")
82
  return svg_content
83
 
84
  def __call__(self, data):
@@ -109,7 +142,6 @@ class EndpointHandler:
109
  debug_log("Generated image from SVG")
110
  except Exception as e:
111
  debug_log(f"Error converting SVG to PNG: {e}")
112
- import traceback
113
  debug_log(traceback.format_exc())
114
  # Create a simple placeholder image
115
  image = Image.new("RGB", (512, 512), color="#f0f0f0")
@@ -123,7 +155,6 @@ class EndpointHandler:
123
  return image
124
  except Exception as e:
125
  debug_log(f"Error in handler: {e}")
126
- import traceback
127
  debug_log(traceback.format_exc())
128
  # Return a simple error image
129
  image = Image.new("RGB", (512, 512), color="#ff0000")
@@ -131,4 +162,4 @@ class EndpointHandler:
131
  draw = ImageDraw.Draw(image)
132
  draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")
133
  debug_log("Returning error image")
134
- return image
 
1
  import os
2
  import io
3
  import sys
 
 
4
  import torch
5
  import numpy as np
6
  from PIL import Image
7
+ import traceback
 
 
 
 
8
 
9
  # Add debug logging
10
  def debug_log(message):
 
19
  debug_log("Successfully imported cairosvg")
20
  except ImportError:
21
  debug_log("cairosvg not found. Installing...")
22
+ import subprocess
23
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
24
  import cairosvg
25
  debug_log("Installed and imported cairosvg")
26
 
27
+ # Add the model directory to the path
28
+ sys.path.append('/code/diffsketcher')
29
+
30
+ # Try to import the model
31
+ try:
32
+ from models.clip_model import ClipModel
33
+ from models.diffusion_model import DiffusionModel
34
+ from models.sketch_model import SketchModel
35
+ debug_log("Successfully imported DiffSketcher models")
36
+ except ImportError as e:
37
+ debug_log(f"Error importing DiffSketcher models: {e}")
38
+ debug_log(traceback.format_exc())
39
+
40
  class EndpointHandler:
41
  def __init__(self, model_dir):
42
  """Initialize the handler with model directory"""
 
46
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
  debug_log(f"Using device: {self.device}")
48
 
49
+ # Initialize the model
50
+ try:
51
+ self.clip_model = ClipModel(device=self.device)
52
+ self.diffusion_model = DiffusionModel(device=self.device)
53
+ self.sketch_model = SketchModel(device=self.device)
 
 
54
 
55
+ # Load checkpoint if available
56
+ weights_path = os.path.join(model_dir, "checkpoint.pth")
57
+ if os.path.exists(weights_path):
58
+ debug_log(f"Loading checkpoint from {weights_path}")
59
+ checkpoint = torch.load(weights_path, map_location=self.device)
60
+ self.sketch_model.load_state_dict(checkpoint['sketch_model'])
61
+ debug_log("Successfully loaded checkpoint")
62
+ self.use_model = True
63
+ else:
64
+ debug_log(f"Checkpoint not found at {weights_path}, using model without pre-trained weights")
65
+ self.use_model = True
66
+ except Exception as e:
67
+ debug_log(f"Error initializing model: {e}")
68
+ debug_log(traceback.format_exc())
69
+ self.use_model = False
70
  except Exception as e:
71
  debug_log(f"Error in handler initialization: {e}")
 
72
  debug_log(traceback.format_exc())
73
  self.use_model = False
74
 
 
76
  """Generate an SVG from a text prompt"""
77
  debug_log(f"Generating SVG for prompt: {prompt}")
78
 
79
+ if self.use_model:
80
+ try:
81
+ debug_log("Using initialized model")
82
+
83
+ # Generate SVG using DiffSketcher
84
+ text_features = self.clip_model.encode_text(prompt)
85
+ latent = self.diffusion_model.generate(text_features)
86
+ svg_data = self.sketch_model.generate(latent, num_paths=20, width=width, height=height)
87
+ debug_log("Generated SVG using DiffSketcher")
88
+ return svg_data
89
+ except Exception as e:
90
+ debug_log(f"Error generating SVG with model: {e}")
91
+ debug_log(traceback.format_exc())
92
+ return self._generate_placeholder_svg(prompt, width, height)
93
+ else:
94
+ debug_log("Model not initialized, using placeholder")
95
+ return self._generate_placeholder_svg(prompt, width, height)
96
+
97
+ def _generate_placeholder_svg(self, prompt, width=512, height=512):
98
+ """Generate a placeholder SVG"""
99
+ debug_log(f"Generating placeholder SVG for prompt: {prompt}")
100
+
101
  # Create a more interesting placeholder that looks like a sketch
102
  svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
103
  <rect width="100%" height="100%" fill="#ffffff"/>
 
111
  <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" fill="#333333">{prompt}</text>
112
  </svg>"""
113
 
114
+ debug_log("Generated placeholder SVG")
115
  return svg_content
116
 
117
  def __call__(self, data):
 
142
  debug_log("Generated image from SVG")
143
  except Exception as e:
144
  debug_log(f"Error converting SVG to PNG: {e}")
 
145
  debug_log(traceback.format_exc())
146
  # Create a simple placeholder image
147
  image = Image.new("RGB", (512, 512), color="#f0f0f0")
 
155
  return image
156
  except Exception as e:
157
  debug_log(f"Error in handler: {e}")
 
158
  debug_log(traceback.format_exc())
159
  # Return a simple error image
160
  image = Image.new("RGB", (512, 512), color="#ff0000")
 
162
  draw = ImageDraw.Draw(image)
163
  draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")
164
  debug_log("Returning error image")
165
+ return image