jree423 commited on
Commit
e942bd1
·
verified ·
1 Parent(s): 2faa160

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +81 -0
pipeline.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Dict, List, Optional, Union
3
+ import torch
4
+ from diffusers import DiffusionPipeline
5
+ from PIL import Image
6
+ import numpy as np
7
+ import io
8
+ import base64
9
+
10
+ class DiffSketcherPipeline(DiffusionPipeline):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.register_modules(
14
+ model=None
15
+ )
16
+
17
+ @torch.no_grad()
18
+ def __call__(
19
+ self,
20
+ prompt: str,
21
+ negative_prompt: str = "",
22
+ num_paths: int = 96,
23
+ token_ind: int = 4,
24
+ num_iter: int = 800,
25
+ guidance_scale: float = 7.5,
26
+ width: float = 1.5,
27
+ seed: Optional[int] = None,
28
+ return_dict: bool = True,
29
+ output_type: str = "pil",
30
+ ) -> Union[Dict, tuple]:
31
+ """
32
+ Generate a vector sketch based on a text prompt.
33
+
34
+ Args:
35
+ prompt: The text prompt to guide the sketch generation.
36
+ negative_prompt: Negative text prompt for guidance.
37
+ num_paths: Number of paths to use in the sketch.
38
+ token_ind: Token index for attention.
39
+ num_iter: Number of optimization iterations.
40
+ guidance_scale: Scale for classifier-free guidance.
41
+ width: Stroke width.
42
+ seed: Random seed for reproducibility.
43
+ return_dict: Whether to return a dict or tuple.
44
+ output_type: Output type, one of "pil", "np", or "svg".
45
+
46
+ Returns:
47
+ If return_dict is True, returns a dict with keys:
48
+ - "svg": SVG string representation of the sketch
49
+ - "image": Rendered image of the sketch
50
+ Otherwise, returns a tuple (svg_string, image)
51
+ """
52
+ # Set seed for reproducibility
53
+ if seed is not None:
54
+ torch.manual_seed(seed)
55
+ np.random.seed(seed)
56
+
57
+ # Generate a placeholder image
58
+ width, height = 512, 512
59
+ image = Image.new('RGB', (width, height), color='white')
60
+
61
+ # Create a simple SVG with the prompt text
62
+ svg_str = f'''<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
63
+ <rect width="100%" height="100%" fill="white"/>
64
+ <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" dominant-baseline="middle" fill="black">
65
+ {prompt}
66
+ </text>
67
+ <text x="50%" y="70%" font-family="Arial" font-size="12" text-anchor="middle" dominant-baseline="middle" fill="gray">
68
+ Paths: {num_paths}, Width: {width}
69
+ </text>
70
+ </svg>'''
71
+
72
+ # Convert output based on output_type
73
+ if output_type == "np":
74
+ image = np.array(image)
75
+ elif output_type == "svg":
76
+ image = svg_str
77
+
78
+ if return_dict:
79
+ return {"svg": svg_str, "image": image}
80
+ else:
81
+ return svg_str, image