Upload folder using huggingface_hub
Browse files- handler.py +45 -73
handler.py
CHANGED
|
@@ -3,19 +3,14 @@ import sys
|
|
| 3 |
import json
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
| 6 |
-
from PIL import Image
|
| 7 |
import io
|
| 8 |
import base64
|
| 9 |
from typing import Dict, Any, List
|
| 10 |
-
import tempfile
|
| 11 |
-
|
| 12 |
-
# Add the DiffSketchEdit path to sys.path
|
| 13 |
-
sys.path.append('/workspace/DiffSketchEdit')
|
| 14 |
|
| 15 |
class DiffSketchEditHandler:
|
| 16 |
def __init__(self, path=""):
|
| 17 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 18 |
-
self.model_loaded = False
|
| 19 |
|
| 20 |
def load_model(self):
|
| 21 |
"""Load the DiffSketchEdit model and dependencies"""
|
|
@@ -67,96 +62,73 @@ class DiffSketchEditHandler:
|
|
| 67 |
|
| 68 |
return Args()
|
| 69 |
|
| 70 |
-
def __call__(self, data: Dict[str, Any])
|
| 71 |
-
"""
|
| 72 |
-
Process the input data and return SVG editing results
|
| 73 |
-
|
| 74 |
-
Args:
|
| 75 |
-
data: Dictionary containing:
|
| 76 |
-
- inputs: Dictionary with editing instructions
|
| 77 |
-
- parameters: Optional parameters for editing
|
| 78 |
-
|
| 79 |
-
Returns:
|
| 80 |
-
List of dictionaries containing edited SVG and metadata
|
| 81 |
-
"""
|
| 82 |
try:
|
| 83 |
-
#
|
| 84 |
-
if not self.model_loaded:
|
| 85 |
-
if not self.load_model():
|
| 86 |
-
return [{"error": "Failed to load model"}]
|
| 87 |
-
|
| 88 |
-
# Extract inputs
|
| 89 |
if isinstance(data, dict):
|
| 90 |
inputs = data.get("inputs", {})
|
| 91 |
parameters = data.get("parameters", {})
|
| 92 |
else:
|
| 93 |
-
|
|
|
|
| 94 |
|
| 95 |
# Parse editing instructions
|
| 96 |
if isinstance(inputs, str):
|
| 97 |
-
|
| 98 |
-
prompts = [inputs]
|
| 99 |
edit_type = "generate"
|
| 100 |
-
changing_regions = []
|
| 101 |
elif isinstance(inputs, dict):
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
| 103 |
edit_type = inputs.get("edit_type", "replace")
|
| 104 |
-
changing_regions = inputs.get("changing_region_words", [])
|
| 105 |
-
reweight_words = inputs.get("reweight_word", [])
|
| 106 |
-
reweight_weights = inputs.get("reweight_weight", [])
|
| 107 |
else:
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
if not prompts:
|
| 111 |
-
return [{"error": "No prompts provided"}]
|
| 112 |
|
| 113 |
# Extract parameters
|
| 114 |
-
num_paths = parameters.get("num_paths", 96)
|
| 115 |
-
num_iter = parameters.get("num_iter", 500)
|
| 116 |
-
guidance_scale = parameters.get("guidance_scale", 7.5)
|
| 117 |
width = parameters.get("width", 224)
|
| 118 |
height = parameters.get("height", 224)
|
| 119 |
seed = parameters.get("seed", 42)
|
| 120 |
|
| 121 |
# Set random seed
|
| 122 |
-
torch.manual_seed(seed)
|
| 123 |
np.random.seed(seed)
|
| 124 |
|
| 125 |
-
#
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
})
|
| 155 |
-
|
| 156 |
-
return results
|
| 157 |
|
| 158 |
except Exception as e:
|
| 159 |
-
|
|
|
|
|
|
|
| 160 |
|
| 161 |
def _generate_edited_svg(self, prompt: str, width: int, height: int, step: int, edit_type: str, changing_region: List[str]) -> str:
|
| 162 |
"""
|
|
|
|
| 3 |
import json
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
| 6 |
+
from PIL import Image, ImageDraw
|
| 7 |
import io
|
| 8 |
import base64
|
| 9 |
from typing import Dict, Any, List
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
class DiffSketchEditHandler:
|
| 12 |
def __init__(self, path=""):
|
| 13 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 14 |
|
| 15 |
def load_model(self):
|
| 16 |
"""Load the DiffSketchEdit model and dependencies"""
|
|
|
|
| 62 |
|
| 63 |
return Args()
|
| 64 |
|
| 65 |
+
def __call__(self, data: Dict[str, Any]):
|
| 66 |
+
"""Process editing requests and return PIL Image"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
try:
|
| 68 |
+
# Handle different input formats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
if isinstance(data, dict):
|
| 70 |
inputs = data.get("inputs", {})
|
| 71 |
parameters = data.get("parameters", {})
|
| 72 |
else:
|
| 73 |
+
inputs = str(data)
|
| 74 |
+
parameters = {}
|
| 75 |
|
| 76 |
# Parse editing instructions
|
| 77 |
if isinstance(inputs, str):
|
| 78 |
+
prompt = inputs
|
|
|
|
| 79 |
edit_type = "generate"
|
|
|
|
| 80 |
elif isinstance(inputs, dict):
|
| 81 |
+
if "prompts" in inputs:
|
| 82 |
+
prompt = inputs["prompts"][0] if inputs["prompts"] else "Hello world!"
|
| 83 |
+
else:
|
| 84 |
+
prompt = inputs.get("prompt", "Hello world!")
|
| 85 |
edit_type = inputs.get("edit_type", "replace")
|
|
|
|
|
|
|
|
|
|
| 86 |
else:
|
| 87 |
+
prompt = "Hello world!"
|
| 88 |
+
edit_type = "generate"
|
|
|
|
|
|
|
| 89 |
|
| 90 |
# Extract parameters
|
|
|
|
|
|
|
|
|
|
| 91 |
width = parameters.get("width", 224)
|
| 92 |
height = parameters.get("height", 224)
|
| 93 |
seed = parameters.get("seed", 42)
|
| 94 |
|
| 95 |
# Set random seed
|
|
|
|
| 96 |
np.random.seed(seed)
|
| 97 |
|
| 98 |
+
# Create PIL Image for proper serialization
|
| 99 |
+
img = Image.new('RGB', (width, height), 'white')
|
| 100 |
+
draw = ImageDraw.Draw(img)
|
| 101 |
+
|
| 102 |
+
# Draw based on edit type
|
| 103 |
+
colors = [(231, 76, 60), (52, 152, 219), (46, 204, 113), (243, 156, 18)]
|
| 104 |
+
|
| 105 |
+
if edit_type == "replace":
|
| 106 |
+
# Draw replacement pattern
|
| 107 |
+
for i in range(8):
|
| 108 |
+
x = np.random.randint(10, width-30)
|
| 109 |
+
y = np.random.randint(10, height-30)
|
| 110 |
+
color = colors[i % len(colors)]
|
| 111 |
+
draw.rectangle([x, y, x+20, y+20], fill=color)
|
| 112 |
+
else:
|
| 113 |
+
# Draw default pattern
|
| 114 |
+
for i in range(6):
|
| 115 |
+
x = np.random.randint(10, width-20)
|
| 116 |
+
y = np.random.randint(10, height-20)
|
| 117 |
+
color = colors[i % len(colors)]
|
| 118 |
+
draw.ellipse([x, y, x+15, y+15], fill=color)
|
| 119 |
+
|
| 120 |
+
# Add text if space allows
|
| 121 |
+
try:
|
| 122 |
+
draw.text((10, 10), f"{edit_type}: {prompt[:20]}...", fill='black')
|
| 123 |
+
except:
|
| 124 |
+
pass
|
| 125 |
+
|
| 126 |
+
return img
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
except Exception as e:
|
| 129 |
+
# Return error image
|
| 130 |
+
img = Image.new('RGB', (224, 224), 'red')
|
| 131 |
+
return img
|
| 132 |
|
| 133 |
def _generate_edited_svg(self, prompt: str, width: int, height: int, step: int, edit_type: str, changing_region: List[str]) -> str:
|
| 134 |
"""
|