jree423 commited on
Commit
a96d3ab
·
verified ·
1 Parent(s): 5edea0f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +70 -0
handler.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ import base64
4
+ import io
5
+ from PIL import Image
6
+ from diffusers import DiffusionPipeline
7
+
8
+ class EndpointHandler:
9
+ def __init__(self, path=""):
10
+ # Initialize the model
11
+ self.model = DiffusionPipeline.from_pretrained(path)
12
+ self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
13
+
14
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
15
+ """
16
+ Args:
17
+ data: Dictionary with the following structure:
18
+ {
19
+ "prompt": str,
20
+ "negative_prompt": str (optional),
21
+ "num_paths": int (optional),
22
+ "num_iter": int (optional),
23
+ "guidance_scale": float (optional),
24
+ "width": int (optional),
25
+ "seed": int (optional)
26
+ }
27
+ Returns:
28
+ Dictionary with the following structure:
29
+ {
30
+ "svg": str,
31
+ "image": str (base64 encoded image)
32
+ }
33
+ """
34
+ # Extract parameters from the input data
35
+ prompt = data.get("prompt", "")
36
+ negative_prompt = data.get("negative_prompt", None)
37
+ num_paths = data.get("num_paths", 96)
38
+ num_iter = data.get("num_iter", 800)
39
+ guidance_scale = data.get("guidance_scale", 7.5)
40
+ width = data.get("width", 2)
41
+ seed = data.get("seed", None)
42
+
43
+ # Set the seed if provided
44
+ if seed is not None:
45
+ torch.manual_seed(seed)
46
+
47
+ # Generate the SVG
48
+ output = self.model(
49
+ prompt=prompt,
50
+ negative_prompt=negative_prompt,
51
+ num_paths=num_paths,
52
+ num_iter=num_iter,
53
+ guidance_scale=guidance_scale,
54
+ width=width
55
+ )
56
+
57
+ # Get the SVG and image
58
+ svg = output.svg
59
+ image = output.images[0]
60
+
61
+ # Convert the image to base64
62
+ buffered = io.BytesIO()
63
+ image.save(buffered, format="PNG")
64
+ img_str = base64.b64encode(buffered.getvalue()).decode()
65
+
66
+ # Return the results
67
+ return {
68
+ "svg": svg,
69
+ "image": img_str
70
+ }