File size: 8,672 Bytes
34ebd85
 
023c371
 
 
34ebd85
023c371
 
37078c5
 
023c371
 
 
 
34ebd85
a021169
51e82cd
023c371
34ebd85
023c371
34ebd85
 
 
023c371
 
34ebd85
023c371
 
 
 
 
34ebd85
023c371
 
 
 
 
 
34ebd85
023c371
34ebd85
 
023c371
 
 
 
317aa01
023c371
34ebd85
023c371
 
 
 
 
 
 
 
 
34ebd85
54f01ed
51e82cd
023c371
 
9ee24ce
317aa01
023c371
54f01ed
9ee24ce
8419185
023c371
 
 
34ebd85
 
317aa01
023c371
 
 
 
 
 
 
34ebd85
37078c5
023c371
 
8d52022
55dc40f
023c371
 
317aa01
023c371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317aa01
023c371
 
317aa01
023c371
8419185
023c371
8419185
023c371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317aa01
023c371
 
 
 
 
8419185
023c371
 
 
 
 
 
 
 
 
 
 
 
 
 
8419185
023c371
 
 
 
 
 
 
 
 
8419185
023c371
 
 
8419185
023c371
 
 
8419185
023c371
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import os
import sys
import tempfile
import shutil
from pathlib import Path
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image
import io
import cairosvg

# Make the DiffSketcher checkout importable: EndpointHandler.__init__ imports
# ``libs.engine`` and ``methods.painter.diffsketcher`` from it. The location
# defaults to /workspace/DiffSketcher and can be overridden with the
# DIFFSKETCHER_PATH environment variable. The entry is appended only if it is
# not already present, so re-executing this module (e.g. importlib.reload)
# does not keep growing sys.path.
DIFFSKETCHER_ROOT = os.environ.get("DIFFSKETCHER_PATH", "/workspace/DiffSketcher")
if DIFFSKETCHER_ROOT not in sys.path:
    sys.path.append(DIFFSKETCHER_ROOT)

class EndpointHandler:
    """Hugging Face Inference Endpoint handler that turns a text prompt into a sketch.

    If the DiffSketcher modules and config load in ``__init__``, sketches are
    produced by ``DiffSketcher.paint``. Otherwise, or when painting fails, a
    simple keyword-based SVG is drawn instead. Either way the SVG is rasterised
    to an RGB ``PIL.Image.Image``.
    """

    def __init__(self, path=""):
        """Initialize DiffSketcher model for Hugging Face Inference API.

        Args:
            path: Model repository directory. ``<path>/config/diffsketcher.yaml``
                is used if it exists; otherwise the ``config/diffsketcher.yaml``
                next to this file.

        On any initialization error, ``self.painter`` and ``self.config`` are
        set to None and requests are served by the fallback SVG generator.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing DiffSketcher on {self.device}")

        try:
            # Imported lazily: these live in the DiffSketcher checkout that is
            # added to sys.path at module level and may be missing.
            from libs.engine import ModelState
            from methods.painter.diffsketcher import DiffSketcher

            config_path = Path(path) / "config" / "diffsketcher.yaml"
            if not config_path.exists():
                # Use the config shipped alongside this handler.
                config_path = Path(__file__).parent / "config" / "diffsketcher.yaml"

            with open(config_path, 'r') as f:
                self.config = OmegaConf.load(f)

            self.model_state = ModelState(self.config)
            self.painter = DiffSketcher(self.config, self.device, self.model_state)

            print("DiffSketcher initialized successfully")

        except Exception as e:
            print(f"Error initializing DiffSketcher: {e}")
            # Fall back to simple SVG generation in __call__.
            self.painter = None
            self.config = None

    def __call__(self, data):
        """
        Generate a sketch image from a text prompt.

        Args:
            data (dict): Request payload containing:
                - inputs (str): Text prompt.
                - parameters (dict, optional): Generation parameters, with
                  defaults ``num_paths=96``, ``num_iter=500``,
                  ``guidance_scale=7.5``, ``seed=42``, ``width=224``,
                  ``height=224``.

        Returns:
            PIL.Image.Image: The rendered sketch, or an error image carrying a
            short message. This method does not raise.
        """
        try:
            prompt = data.get("inputs", "")
            # A JSON payload with ``"parameters": null`` arrives as None.
            parameters = data.get("parameters") or {}

            if not prompt:
                return self._create_error_image("No prompt provided")

            # JSON clients often send numbers as strings. Coerce once here so
            # torch.manual_seed and the SVG geometry receive real numbers.
            num_paths = int(parameters.get("num_paths", 96))
            num_iter = int(parameters.get("num_iter", 500))
            guidance_scale = float(parameters.get("guidance_scale", 7.5))
            seed = int(parameters.get("seed", 42))
            width = int(parameters.get("width", 224))
            height = int(parameters.get("height", 224))
            if width <= 0 or height <= 0:
                return self._create_error_image("width and height must be positive")

            if self.painter is not None:
                svg_content = self._generate_with_diffsketcher(
                    prompt, num_paths, num_iter, guidance_scale, seed,
                    width=width, height=height,
                )
            else:
                svg_content = self._generate_fallback_svg(prompt, width, height)

            return self._svg_to_image(svg_content, width, height)

        except Exception as e:
            # Endpoint boundary: report the failure as an image instead of raising.
            print(f"Error in DiffSketcher inference: {e}")
            return self._create_error_image(f"Error: {str(e)[:50]}")

    def _generate_with_diffsketcher(self, prompt, num_paths, num_iter, guidance_scale, seed,
                                    width=224, height=224):
        """Generate SVG markup with the loaded DiffSketcher painter.

        Args:
            prompt: Text prompt forwarded to ``painter.paint``.
            num_paths: Stroke count forwarded to ``painter.paint``.
            num_iter: Iteration count forwarded to ``painter.paint``.
            guidance_scale: Accepted for interface compatibility; not forwarded
                (see the note in the body).
            seed: Passed to ``torch.manual_seed`` before painting.
            width: Canvas width used only if the fallback SVG is needed.
            height: Canvas height used only if the fallback SVG is needed.

        Returns:
            str: SVG markup. This is the first ``*.svg`` (in sorted name order)
            that the painter wrote into its output directory, or the fallback
            sketch if painting raised or wrote no SVG.
        """
        try:
            torch.manual_seed(seed)

            with tempfile.TemporaryDirectory() as temp_dir:
                output_dir = Path(temp_dir) / "output"
                output_dir.mkdir(exist_ok=True)

                # NOTE(review): guidance_scale is not forwarded. The only
                # paint() keyword arguments known here are prompt, output_dir,
                # num_paths and num_iter. Confirm the painter's API before
                # wiring guidance_scale through.
                self.painter.paint(
                    prompt=prompt,
                    output_dir=str(output_dir),
                    num_paths=num_paths,
                    num_iter=num_iter,
                )

                # glob() order is filesystem-dependent; sort for a stable pick.
                # NOTE(review): if paint() writes several SVGs (e.g. intermediate
                # steps), confirm which one is the final result.
                svg_files = sorted(output_dir.glob("*.svg"))
                if not svg_files:
                    raise RuntimeError("No SVG file generated")
                # Must be read before leaving the with-block, which deletes temp_dir.
                return svg_files[0].read_text(encoding="utf-8")

        except Exception as e:
            print(f"DiffSketcher generation failed: {e}")
            return self._generate_fallback_svg(prompt, width, height)

    def _generate_fallback_svg(self, prompt, width, height):
        """Generate a simple SVG sketch when the model is unavailable or fails.

        Prompts mentioning a car/vehicle/automobile or a house/building/home get
        a fixed line drawing centred on the canvas. Any other prompt gets five
        pseudo-random outline shapes (circles, squares, triangles).

        The output is deterministic for a given (prompt, width, height), also
        across processes, and the global ``random`` state is left untouched.

        Args:
            prompt: Text prompt; only used for keyword matching and seeding.
            width: Canvas width in px.
            height: Canvas height in px.

        Returns:
            str: Complete SVG document with a white background.
        """
        import math
        import random
        import zlib

        # Built-in hash() of a str is salted per process (PYTHONHASHSEED), so it
        # is not reproducible across restarts; crc32 is. A private Random
        # instance avoids reseeding the module-level generator other code uses.
        rng = random.Random(zlib.crc32(prompt.encode("utf-8")))

        svg_parts = [f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">']
        svg_parts.append(f'<rect width="{width}" height="{height}" fill="white"/>')

        prompt_lower = prompt.lower()
        cx, cy = width // 2, height // 2

        if any(word in prompt_lower for word in ['car', 'vehicle', 'automobile']):
            # Simple car: body, cabin and two wheels.
            svg_parts.extend([
                f'<rect x="{cx-60}" y="{cy-20}" width="120" height="40" fill="none" stroke="black" stroke-width="2"/>',
                f'<rect x="{cx-40}" y="{cy-40}" width="80" height="20" fill="none" stroke="black" stroke-width="2"/>',
                f'<circle cx="{cx-35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>',
                f'<circle cx="{cx+35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>'
            ])
        elif any(word in prompt_lower for word in ['house', 'building', 'home']):
            # Simple house: walls, roof, door and two windows.
            svg_parts.extend([
                f'<rect x="{cx-50}" y="{cy-10}" width="100" height="50" fill="none" stroke="black" stroke-width="2"/>',
                f'<polygon points="{cx-60},{cy-10} {cx},{cy-50} {cx+60},{cy-10}" fill="none" stroke="black" stroke-width="2"/>',
                f'<rect x="{cx-15}" y="{cy+10}" width="30" height="30" fill="none" stroke="black" stroke-width="2"/>',
                f'<rect x="{cx-40}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>',
                f'<rect x="{cx+25}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>'
            ])
        else:
            # Abstract sketch. Shape centres stay `margin` px from the edges
            # when the canvas allows it. On canvases narrower/shorter than
            # 2*margin the range collapses to the centre line, because
            # randint raises ValueError when low > high.
            margin = 20
            x_lo, x_hi = min(margin, cx), max(width - margin, cx)
            y_lo, y_hi = min(margin, cy), max(height - margin, cy)
            for i in range(5):
                x = rng.randint(x_lo, x_hi)
                y = rng.randint(y_lo, y_hi)
                size = rng.randint(10, 30)

                if i % 3 == 0:
                    svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="none" stroke="black" stroke-width="2"/>')
                elif i % 3 == 1:
                    svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="none" stroke="black" stroke-width="2"/>')
                else:
                    # Triangle with vertices 120 degrees apart on a circle of radius `size`.
                    points = []
                    for j in range(3):
                        px = x + size * math.cos(j * 120 * math.pi / 180)
                        py = y + size * math.sin(j * 120 * math.pi / 180)
                        points.append(f"{px},{py}")
                    svg_parts.append(f'<polygon points="{" ".join(points)}" fill="none" stroke="black" stroke-width="2"/>')

        svg_parts.append('</svg>')
        return '\n'.join(svg_parts)

    def _svg_to_image(self, svg_content, width=224, height=224):
        """Rasterise SVG markup to a ``width`` x ``height`` RGB PIL image.

        Transparent regions are flattened onto white. If conversion fails, an
        error image of the same size is returned.
        """
        try:
            png_data = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height
            )

            # cairosvg leaves unpainted areas fully transparent. A plain
            # convert('RGB') drops alpha and typically renders those areas
            # black, so composite onto white first.
            rendered = Image.open(io.BytesIO(png_data)).convert('RGBA')
            background = Image.new('RGBA', rendered.size, 'white')
            return Image.alpha_composite(background, rendered).convert('RGB')

        except Exception as e:
            print(f"Error converting SVG to image: {e}")
            return self._create_error_image("SVG conversion failed", width, height)

    def _create_error_image(self, message, width=224, height=224):
        """Create a white RGB image with ``message`` written on it.

        Drawing the text is best effort. If it fails (e.g. a character the
        default font cannot encode), the blank white image is returned.
        """
        import textwrap

        image = Image.new('RGB', (width, height), 'white')
        try:
            from PIL import ImageDraw

            # Rough wrap width, assuming ~6 px per character for PIL's default
            # bitmap font. This is an approximation, not a measured metric.
            chars_per_line = max(1, (width - 20) // 6)
            text = "\n".join(textwrap.wrap(str(message), chars_per_line))
            ImageDraw.Draw(image).multiline_text((10, 10), text, fill='black')
        except Exception as e:
            print(f"Could not draw error message: {e}")
        return image