jree423 commited on
Commit
326670f
·
verified ·
1 Parent(s): 5560bac

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +61 -23
handler.py CHANGED
@@ -7,15 +7,13 @@ from PIL import Image
7
  import traceback
8
  import json
9
  import logging
 
10
 
11
  # Configure logging
12
  logging.basicConfig(level=logging.INFO,
13
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
14
  logger = logging.getLogger(__name__)
15
 
16
- # Add the model directory to the path
17
- sys.path.append('/code/diffsketcher_edit')
18
-
19
  # Safely import cairosvg with fallback
20
  try:
21
  import cairosvg
@@ -29,35 +27,63 @@ except ImportError:
29
 
30
  class EndpointHandler:
31
  def __init__(self, model_dir):
32
- # Initialize the handler with model directory
33
  logger.info(f"Initializing handler with model_dir: {model_dir}")
34
  self.model_dir = model_dir
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  logger.info(f"Using device: {self.device}")
37
 
38
  # Initialize the model
39
- logger.info("Initializing diffsketcher_edit model...")
40
- self.model = self._initialize_model()
41
- logger.info("diffsketcher_edit model initialized")
42
 
43
  def _initialize_model(self):
44
- # Initialize the diffsketcher_edit model
45
- # This is a placeholder for the actual model initialization
46
- # In a real implementation, you would load the model weights and initialize the model
47
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- def generate_svg(self, prompt, width=512, height=512, num_paths=512, seed=None):
50
- # Generate an SVG from a text prompt
51
  logger.info(f"Generating SVG for prompt: {prompt}")
52
 
53
- # This is a placeholder for the actual SVG generation
54
- # In a real implementation, you would use the model to generate an SVG
55
- svg_content = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg"><text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20">{prompt}</text></svg>'
 
 
 
 
 
 
 
 
 
56
 
57
  return svg_content
58
 
59
  def __call__(self, data):
60
- # Handle a request to the model
61
  try:
62
  logger.info(f"Handling request with data: {data}")
63
 
@@ -84,13 +110,25 @@ class EndpointHandler:
84
  logger.info(f"Extracted parameters: {params}")
85
 
86
  # Extract parameters
87
- width = params.get("width", 512)
88
- height = params.get("height", 512)
89
- num_paths = params.get("num_paths", 512)
90
  seed = params.get("seed", None)
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  # Generate SVG
93
- svg_content = self.generate_svg(prompt, width, height, num_paths, seed)
94
  logger.info("SVG content generated")
95
 
96
  # Convert SVG to PNG
@@ -99,11 +137,11 @@ class EndpointHandler:
99
  image = Image.open(io.BytesIO(png_data))
100
  logger.info(f"Converted to PNG with size: {image.size}")
101
 
102
- # Return the PIL Image directly
103
  return image
104
  except Exception as e:
105
  logger.error(f"Error in handler: {e}")
106
  logger.error(traceback.format_exc())
107
  # Return an error image
108
  error_image = Image.new('RGB', (512, 512), color='red')
109
- return error_image
 
7
  import traceback
8
  import json
9
  import logging
10
+ import base64
11
 
12
  # Configure logging
13
  logging.basicConfig(level=logging.INFO,
14
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
15
  logger = logging.getLogger(__name__)
16
 
 
 
 
17
  # Safely import cairosvg with fallback
18
  try:
19
  import cairosvg
 
27
 
28
  class EndpointHandler:
29
  def __init__(self, model_dir):
30
+ """Initialize the handler with model directory"""
31
  logger.info(f"Initializing handler with model_dir: {model_dir}")
32
  self.model_dir = model_dir
33
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
  logger.info(f"Using device: {self.device}")
35
 
36
  # Initialize the model
37
+ logger.info("Initializing DiffSketchEdit model...")
38
+ self._initialize_model()
39
+ logger.info("DiffSketchEdit model initialized")
40
 
41
  def _initialize_model(self):
42
+ """Initialize the DiffSketchEdit model"""
43
+ # This is a simplified initialization that doesn't rely on external imports
44
+ logger.info("Using simplified model initialization")
45
+
46
+ # Add the current directory to the path
47
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
48
+
49
+ # Try to import CLIP
50
+ try:
51
+ import clip
52
+ logger.info("Successfully imported CLIP")
53
+ except ImportError:
54
+ logger.warning("CLIP not found. Installing...")
55
+ subprocess.check_call(["pip", "install", "git+https://github.com/openai/CLIP.git"])
56
+ import clip
57
+ logger.info("Successfully installed and imported CLIP")
58
+
59
+ # Try to import diffvg
60
+ try:
61
+ import diffvg
62
+ logger.info("Successfully imported diffvg")
63
+ except ImportError:
64
+ logger.warning("diffvg not found. Using placeholder implementation")
65
 
66
+ def generate_svg(self, prompt, source_image=None, width=512, height=512, num_paths=512, seed=None):
67
+ """Generate an SVG from a text prompt and optionally a source image"""
68
  logger.info(f"Generating SVG for prompt: {prompt}")
69
 
70
+ # Set a seed for reproducibility
71
+ if seed is not None:
72
+ torch.manual_seed(seed)
73
+ np.random.seed(seed)
74
+
75
+ # Create a simple SVG with the prompt text
76
+ # In a real implementation, this would use the DiffSketchEdit model
77
+ svg_content = f'''<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
78
+ <rect width="100%" height="100%" fill="#fff0f5"/>
79
+ <text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20" fill="#cc0066">{prompt}</text>
80
+ <text x="50%" y="70%" dominant-baseline="middle" text-anchor="middle" font-size="14" fill="#666">DiffSketchEdit placeholder output</text>
81
+ </svg>'''
82
 
83
  return svg_content
84
 
85
  def __call__(self, data):
86
+ """Handle a request to the model"""
87
  try:
88
  logger.info(f"Handling request with data: {data}")
89
 
 
110
  logger.info(f"Extracted parameters: {params}")
111
 
112
  # Extract parameters
113
+ width = int(params.get("width", 512))
114
+ height = int(params.get("height", 512))
115
+ num_paths = int(params.get("num_paths", 512))
116
  seed = params.get("seed", None)
117
+ if seed is not None:
118
+ seed = int(seed)
119
+
120
+ # Extract source image if provided
121
+ source_image = None
122
+ if "image" in params:
123
+ try:
124
+ image_data = base64.b64decode(params["image"])
125
+ source_image = Image.open(io.BytesIO(image_data))
126
+ logger.info(f"Extracted source image with size: {source_image.size}")
127
+ except Exception as e:
128
+ logger.error(f"Error extracting source image: {e}")
129
 
130
  # Generate SVG
131
+ svg_content = self.generate_svg(prompt, source_image, width, height, num_paths, seed)
132
  logger.info("SVG content generated")
133
 
134
  # Convert SVG to PNG
 
137
  image = Image.open(io.BytesIO(png_data))
138
  logger.info(f"Converted to PNG with size: {image.size}")
139
 
140
+ # Return the image
141
  return image
142
  except Exception as e:
143
  logger.error(f"Error in handler: {e}")
144
  logger.error(traceback.format_exc())
145
  # Return an error image
146
  error_image = Image.new('RGB', (512, 512), color='red')
147
+ return error_image