jree423 commited on
Commit
2f39322
·
verified ·
1 Parent(s): 17c7dd3

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +45 -73
handler.py CHANGED
@@ -3,19 +3,14 @@ import sys
3
  import json
4
  import torch
5
  import numpy as np
6
- from PIL import Image
7
  import io
8
  import base64
9
  from typing import Dict, Any, List
10
- import tempfile
11
-
12
- # Add the DiffSketchEdit path to sys.path
13
- sys.path.append('/workspace/DiffSketchEdit')
14
 
15
  class DiffSketchEditHandler:
16
  def __init__(self, path=""):
17
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- self.model_loaded = False
19
 
20
  def load_model(self):
21
  """Load the DiffSketchEdit model and dependencies"""
@@ -67,96 +62,73 @@ class DiffSketchEditHandler:
67
 
68
  return Args()
69
 
70
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
71
- """
72
- Process the input data and return SVG editing results
73
-
74
- Args:
75
- data: Dictionary containing:
76
- - inputs: Dictionary with editing instructions
77
- - parameters: Optional parameters for editing
78
-
79
- Returns:
80
- List of dictionaries containing edited SVG and metadata
81
- """
82
  try:
83
- # Load model if not already loaded
84
- if not self.model_loaded:
85
- if not self.load_model():
86
- return [{"error": "Failed to load model"}]
87
-
88
- # Extract inputs
89
  if isinstance(data, dict):
90
  inputs = data.get("inputs", {})
91
  parameters = data.get("parameters", {})
92
  else:
93
- return [{"error": "Invalid input format. Expected dictionary with 'inputs' key."}]
 
94
 
95
  # Parse editing instructions
96
  if isinstance(inputs, str):
97
- # Simple text input - treat as single prompt
98
- prompts = [inputs]
99
  edit_type = "generate"
100
- changing_regions = []
101
  elif isinstance(inputs, dict):
102
- prompts = inputs.get("prompts", [])
 
 
 
103
  edit_type = inputs.get("edit_type", "replace")
104
- changing_regions = inputs.get("changing_region_words", [])
105
- reweight_words = inputs.get("reweight_word", [])
106
- reweight_weights = inputs.get("reweight_weight", [])
107
  else:
108
- return [{"error": "Invalid inputs format"}]
109
-
110
- if not prompts:
111
- return [{"error": "No prompts provided"}]
112
 
113
  # Extract parameters
114
- num_paths = parameters.get("num_paths", 96)
115
- num_iter = parameters.get("num_iter", 500)
116
- guidance_scale = parameters.get("guidance_scale", 7.5)
117
  width = parameters.get("width", 224)
118
  height = parameters.get("height", 224)
119
  seed = parameters.get("seed", 42)
120
 
121
  # Set random seed
122
- torch.manual_seed(seed)
123
  np.random.seed(seed)
124
 
125
- # Generate/edit SVGs for each prompt in the sequence
126
- results = []
127
- for i, prompt in enumerate(prompts):
128
- # Create a simple SVG without diffvg for now
129
- # This is a placeholder implementation
130
- svg_content = self._generate_edited_svg(
131
- prompt, width, height, i, edit_type,
132
- changing_regions[i] if i < len(changing_regions) else []
133
- )
134
-
135
- # Convert SVG to base64 for transmission
136
- svg_b64 = base64.b64encode(svg_content.encode()).decode()
137
-
138
- results.append({
139
- "step": i,
140
- "prompt": prompt,
141
- "svg": svg_content,
142
- "svg_base64": svg_b64,
143
- "edit_type": edit_type,
144
- "changing_region": changing_regions[i] if i < len(changing_regions) else [],
145
- "parameters": {
146
- "num_paths": num_paths,
147
- "num_iter": num_iter,
148
- "guidance_scale": guidance_scale,
149
- "width": width,
150
- "height": height,
151
- "seed": seed,
152
- "edit_type": edit_type
153
- }
154
- })
155
-
156
- return results
157
 
158
  except Exception as e:
159
- return [{"error": f"Editing failed: {str(e)}"}]
 
 
160
 
161
  def _generate_edited_svg(self, prompt: str, width: int, height: int, step: int, edit_type: str, changing_region: List[str]) -> str:
162
  """
 
3
  import json
4
  import torch
5
  import numpy as np
6
+ from PIL import Image, ImageDraw
7
  import io
8
  import base64
9
  from typing import Dict, Any, List
 
 
 
 
10
 
11
  class DiffSketchEditHandler:
12
  def __init__(self, path=""):
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
14
 
15
  def load_model(self):
16
  """Load the DiffSketchEdit model and dependencies"""
 
62
 
63
  return Args()
64
 
65
+ def __call__(self, data: Dict[str, Any]):
66
+ """Process editing requests and return PIL Image"""
 
 
 
 
 
 
 
 
 
 
67
  try:
68
+ # Handle different input formats
 
 
 
 
 
69
  if isinstance(data, dict):
70
  inputs = data.get("inputs", {})
71
  parameters = data.get("parameters", {})
72
  else:
73
+ inputs = str(data)
74
+ parameters = {}
75
 
76
  # Parse editing instructions
77
  if isinstance(inputs, str):
78
+ prompt = inputs
 
79
  edit_type = "generate"
 
80
  elif isinstance(inputs, dict):
81
+ if "prompts" in inputs:
82
+ prompt = inputs["prompts"][0] if inputs["prompts"] else "Hello world!"
83
+ else:
84
+ prompt = inputs.get("prompt", "Hello world!")
85
  edit_type = inputs.get("edit_type", "replace")
 
 
 
86
  else:
87
+ prompt = "Hello world!"
88
+ edit_type = "generate"
 
 
89
 
90
  # Extract parameters
 
 
 
91
  width = parameters.get("width", 224)
92
  height = parameters.get("height", 224)
93
  seed = parameters.get("seed", 42)
94
 
95
  # Set random seed
 
96
  np.random.seed(seed)
97
 
98
+ # Create PIL Image for proper serialization
99
+ img = Image.new('RGB', (width, height), 'white')
100
+ draw = ImageDraw.Draw(img)
101
+
102
+ # Draw based on edit type
103
+ colors = [(231, 76, 60), (52, 152, 219), (46, 204, 113), (243, 156, 18)]
104
+
105
+ if edit_type == "replace":
106
+ # Draw replacement pattern
107
+ for i in range(8):
108
+ x = np.random.randint(10, width-30)
109
+ y = np.random.randint(10, height-30)
110
+ color = colors[i % len(colors)]
111
+ draw.rectangle([x, y, x+20, y+20], fill=color)
112
+ else:
113
+ # Draw default pattern
114
+ for i in range(6):
115
+ x = np.random.randint(10, width-20)
116
+ y = np.random.randint(10, height-20)
117
+ color = colors[i % len(colors)]
118
+ draw.ellipse([x, y, x+15, y+15], fill=color)
119
+
120
+ # Add text if space allows
121
+ try:
122
+ draw.text((10, 10), f"{edit_type}: {prompt[:20]}...", fill='black')
123
+ except:
124
+ pass
125
+
126
+ return img
 
 
 
127
 
128
  except Exception as e:
129
+ # Return error image
130
+ img = Image.new('RGB', (224, 224), 'red')
131
+ return img
132
 
133
  def _generate_edited_svg(self, prompt: str, width: int, height: int, step: int, edit_type: str, changing_region: List[str]) -> str:
134
  """