|
|
import os |
|
|
import sys |
|
|
import tempfile |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
import torch |
|
|
import yaml |
|
|
from omegaconf import OmegaConf |
|
|
from PIL import Image |
|
|
import io |
|
|
import cairosvg |
|
|
|
|
|
|
|
|
# Make the DiffSketcher project packages (e.g. ``libs.engine``, ``methods.painter``)
# importable. The checkout location can be overridden with the DIFFSKETCHER_ROOT
# environment variable; the default is the container path used previously.
_DIFFSKETCHER_ROOT = os.environ.get("DIFFSKETCHER_ROOT", "/workspace/DiffSketcher")
# Avoid piling up duplicate entries if this module is imported more than once.
if _DIFFSKETCHER_ROOT not in sys.path:
    sys.path.append(_DIFFSKETCHER_ROOT)
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler that turns text prompts into sketches.

    If DiffSketcher cannot be initialized (e.g. the project modules or the config
    file are missing), ``self.painter`` stays ``None`` and requests are answered
    with a simple procedurally drawn SVG instead, so callers always get an image.
    """

    def __init__(self, path=""):
        """Initialize DiffSketcher model for Hugging Face Inference API.

        Args:
            path: Deployment directory. ``<path>/config/diffsketcher.yaml`` is tried
                first, then the ``config`` directory next to this file.

        Any initialization error is printed and leaves ``config``, ``model_state``
        and ``painter`` set to ``None`` (fallback mode).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing DiffSketcher on {self.device}")

        # "Model unavailable" state; replaced below only if everything loads.
        self.config = None
        self.model_state = None
        self.painter = None

        try:
            # Project-local modules, presumably resolved through the DiffSketcher
            # checkout that is appended to sys.path at module level.
            from libs.engine import ModelState
            from methods.painter.diffsketcher import DiffSketcher

            config_path = Path(path) / "config" / "diffsketcher.yaml"
            if not config_path.exists():
                # Fall back to the config shipped alongside this handler file.
                config_path = Path(__file__).parent / "config" / "diffsketcher.yaml"

            with open(config_path, "r", encoding="utf-8") as f:
                config = OmegaConf.load(f)

            model_state = ModelState(config)
            painter = DiffSketcher(config, self.device, model_state)
        except Exception as e:
            print(f"Error initializing DiffSketcher: {e}")
            return

        self.config = config
        self.model_state = model_state
        self.painter = painter
        print("DiffSketcher initialized successfully")

    def __call__(self, data):
        """
        Generate sketch image from text prompt

        Args:
            data (dict): Input data containing:
                - inputs (str): Text prompt
                - parameters (dict, optional): Generation parameters
                  ``num_paths`` (default 96), ``num_iter`` (500),
                  ``guidance_scale`` (7.5), ``seed`` (42),
                  ``width`` (224), ``height`` (224)

        Returns:
            PIL.Image.Image: Generated sketch image, or an error image carrying a
            short message when the request cannot be served.
        """
        try:
            prompt = data.get("inputs", "")
            # A JSON payload with ``"parameters": null`` arrives as None.
            parameters = data.get("parameters") or {}

            if not prompt or (isinstance(prompt, str) and not prompt.strip()):
                return self._create_error_image("No prompt provided")

            # Coerce so JSON strings/floats ("96", 224.0) work; invalid values raise
            # and are reported through the error image in the handler below.
            num_paths = int(parameters.get("num_paths", 96))
            num_iter = int(parameters.get("num_iter", 500))
            guidance_scale = float(parameters.get("guidance_scale", 7.5))
            seed = int(parameters.get("seed", 42))
            width = int(parameters.get("width", 224))
            height = int(parameters.get("height", 224))
            if width <= 0 or height <= 0:
                raise ValueError(f"width and height must be positive, got {width}x{height}")

            if self.painter is not None:
                svg_content = self._generate_with_diffsketcher(
                    prompt, num_paths, num_iter, guidance_scale, seed,
                    width=width, height=height,
                )
            else:
                svg_content = self._generate_fallback_svg(prompt, width, height)

            return self._svg_to_image(svg_content, width, height)

        except Exception as e:
            print(f"Error in DiffSketcher inference: {e}")
            return self._create_error_image(f"Error: {str(e)[:50]}")

    def _generate_with_diffsketcher(self, prompt, num_paths, num_iter, guidance_scale, seed,
                                    width=224, height=224):
        """Generate SVG markup using the loaded DiffSketcher painter.

        Args:
            prompt: Text prompt forwarded to ``painter.paint``.
            num_paths: Forwarded to ``painter.paint``.
            num_iter: Forwarded to ``painter.paint``.
            guidance_scale: Accepted for interface compatibility only.
                NOTE(review): it is not forwarded to ``painter.paint``; confirm
                whether the painter should receive it.
            seed: Passed to ``torch.manual_seed`` before painting.
            width, height: Canvas size of the fallback SVG used if painting fails.

        Returns:
            str: Contents of the SVG written by the painter, or a fallback SVG
            (see ``_generate_fallback_svg``) on any failure.
        """
        try:
            torch.manual_seed(seed)

            with tempfile.TemporaryDirectory() as temp_dir:
                output_dir = Path(temp_dir) / "output"
                output_dir.mkdir(exist_ok=True)

                self.painter.paint(
                    prompt=prompt,
                    output_dir=str(output_dir),
                    num_paths=num_paths,
                    num_iter=num_iter,
                )

                # glob() order is filesystem-dependent; sort for a deterministic pick.
                svg_files = sorted(output_dir.glob("*.svg"))
                if not svg_files:
                    raise RuntimeError("No SVG file generated")
                # Must be read before the temporary directory is deleted on exit.
                return svg_files[0].read_text(encoding="utf-8")

        except Exception as e:
            print(f"DiffSketcher generation failed: {e}")
            return self._generate_fallback_svg(prompt, width, height)

    def _generate_fallback_svg(self, prompt, width, height):
        """Draw a simple black-outline SVG when the model is unavailable or fails.

        Prompts mentioning a car or a house get a fixed pictogram centred on the
        canvas. Any other prompt gets five pseudo-random outline shapes (circle,
        square, triangle, repeating).

        Returns:
            str: An SVG document of the given width and height on a white background.
        """
        import math
        import random
        import zlib

        # crc32 is stable across interpreter runs, unlike str hash(), which is salted
        # per process. A private Random instance leaves the global RNG untouched.
        rng = random.Random(zlib.crc32(prompt.encode("utf-8")))

        svg_parts = [f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">']
        svg_parts.append(f'<rect width="{width}" height="{height}" fill="white"/>')

        prompt_lower = prompt.lower()
        cx, cy = width // 2, height // 2

        if any(word in prompt_lower for word in ('car', 'vehicle', 'automobile')):
            # Car pictogram: body, cabin and two wheels.
            svg_parts.extend([
                f'<rect x="{cx-60}" y="{cy-20}" width="120" height="40" fill="none" stroke="black" stroke-width="2"/>',
                f'<rect x="{cx-40}" y="{cy-40}" width="80" height="20" fill="none" stroke="black" stroke-width="2"/>',
                f'<circle cx="{cx-35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>',
                f'<circle cx="{cx+35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>',
            ])
        elif any(word in prompt_lower for word in ('house', 'building', 'home')):
            # House pictogram: walls, roof, door and two windows.
            svg_parts.extend([
                f'<rect x="{cx-50}" y="{cy-10}" width="100" height="50" fill="none" stroke="black" stroke-width="2"/>',
                f'<polygon points="{cx-60},{cy-10} {cx},{cy-50} {cx+60},{cy-10}" fill="none" stroke="black" stroke-width="2"/>',
                f'<rect x="{cx-15}" y="{cy+10}" width="30" height="30" fill="none" stroke="black" stroke-width="2"/>',
                f'<rect x="{cx-40}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>',
                f'<rect x="{cx+25}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>',
            ])
        else:
            # Keep shape centres inside the canvas. The 20px margin shrinks on tiny
            # canvases so randint never gets an empty range (lo > hi raises ValueError).
            margin_x = min(20, width // 2)
            margin_y = min(20, height // 2)
            for i in range(5):
                x = rng.randint(margin_x, width - margin_x)
                y = rng.randint(margin_y, height - margin_y)
                size = rng.randint(10, 30)

                if i % 3 == 0:
                    svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="none" stroke="black" stroke-width="2"/>')
                elif i % 3 == 1:
                    svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="none" stroke="black" stroke-width="2"/>')
                else:
                    # Equilateral triangle: three vertices 120 degrees apart at radius ``size``.
                    points = []
                    for j in range(3):
                        px = x + size * math.cos(j * 120 * math.pi / 180)
                        py = y + size * math.sin(j * 120 * math.pi / 180)
                        points.append(f"{px},{py}")
                    svg_parts.append(f'<polygon points="{" ".join(points)}" fill="none" stroke="black" stroke-width="2"/>')

        svg_parts.append('</svg>')
        return '\n'.join(svg_parts)

    def _svg_to_image(self, svg_content, width=224, height=224):
        """Rasterize SVG markup to an RGB PIL image of the given size.

        Transparent regions are composited onto white. Without this, a plain
        RGBA->RGB conversion keeps the RGB of fully transparent pixels, which is
        typically black.

        Returns:
            PIL.Image.Image: The rendered image, or an error image of the same size
            if rendering fails.
        """
        try:
            png_data = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height,
            )

            rgba = Image.open(io.BytesIO(png_data)).convert('RGBA')
            background = Image.new('RGBA', rgba.size, 'white')
            return Image.alpha_composite(background, rgba).convert('RGB')

        except Exception as e:
            print(f"Error converting SVG to image: {e}")
            return self._create_error_image("SVG conversion failed", width, height)

    def _create_error_image(self, message, width=224, height=224):
        """Return a white RGB image with *message* written in its top-left corner.

        Drawing the text is best effort. If it fails, the plain white image is
        still returned.
        """
        image = Image.new('RGB', (width, height), 'white')
        try:
            import textwrap
            from PIL import ImageDraw

            # Rough wrap width; assumes the default PIL font is about 6-7px per
            # character. Adjust if text overflows with the deployed Pillow version.
            chars_per_line = max(1, (width - 16) // 7)
            text = "\n".join(textwrap.wrap(str(message), width=chars_per_line))
            if text:
                ImageDraw.Draw(image).multiline_text((8, 8), text, fill="black")
        except Exception as e:
            print(f"Could not draw error message: {e}")
        return image