jree423 commited on
Commit
8de5d7e
·
verified ·
1 Parent(s): 3b4f763

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +13 -43
handler.py CHANGED
@@ -1,16 +1,11 @@
1
- from typing import Dict, List, Any
 
2
  import torch
3
  import base64
4
  import io
5
  import os
6
  import json
7
  from PIL import Image
8
- from diffusers import DiffusionPipeline
9
-
10
- class DiffSketcherPipeline:
11
- def __init__(self):
12
- # This is a placeholder class that will be replaced by the actual implementation
13
- pass
14
 
15
  class EndpointHandler:
16
  def __init__(self, path=""):
@@ -22,7 +17,7 @@ class EndpointHandler:
22
  else:
23
  # Create a default config
24
  self.config = {
25
- "architecture": "DiffSketcherPipeline",
26
  "format": "diffusers",
27
  "version": "0.1.0"
28
  }
@@ -30,48 +25,22 @@ class EndpointHandler:
30
  with open(model_index_path, "w") as f:
31
  json.dump(self.config, f, indent=2)
32
 
33
- # Initialize a simple pipeline for now
34
- self.model = DiffSketcherPipeline()
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
 
37
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
38
- """
39
- Args:
40
- data: Dictionary with the following structure:
41
- {
42
- "prompt": str,
43
- "negative_prompt": str (optional),
44
- "num_paths": int (optional),
45
- "num_iter": int (optional),
46
- "guidance_scale": float (optional),
47
- "width": int (optional),
48
- "seed": int (optional)
49
- }
50
- Returns:
51
- Dictionary with the following structure:
52
- {
53
- "svg": str,
54
- "image": str (base64 encoded image)
55
- }
56
- """
57
- # Extract parameters from the input data
58
  prompt = data.get("prompt", "")
59
- negative_prompt = data.get("negative_prompt", None)
60
- num_paths = data.get("num_paths", 96)
61
- num_iter = data.get("num_iter", 800)
62
- guidance_scale = data.get("guidance_scale", 7.5)
63
- width = data.get("width", 2)
64
- seed = data.get("seed", None)
65
 
66
- # Set the seed if provided
67
- if seed is not None:
68
- torch.manual_seed(seed)
69
-
70
  # Generate a placeholder SVG
71
- svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" viewBox="0 0 512 512"><text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20">DiffSketcher: {prompt}</text></svg>'
 
72
 
73
  # Create a placeholder image
74
- image = Image.new('RGB', (512, 512), color = (73, 109, 137))
75
 
76
  # Convert the image to base64
77
  buffered = io.BytesIO()
@@ -82,4 +51,5 @@ class EndpointHandler:
82
  return {
83
  "svg": svg,
84
  "image": img_str
85
- }
 
 
1
+
2
+ from typing import Dict, Any
3
  import torch
4
  import base64
5
  import io
6
  import os
7
  import json
8
  from PIL import Image
 
 
 
 
 
 
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
 
17
  else:
18
  # Create a default config
19
  self.config = {
20
+ "architecture": "SimplePipeline",
21
  "format": "diffusers",
22
  "version": "0.1.0"
23
  }
 
25
  with open(model_index_path, "w") as f:
26
  json.dump(self.config, f, indent=2)
27
 
28
+ # Initialize device
 
29
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
32
+ # Extract prompt from the input data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  prompt = data.get("prompt", "")
34
+ if not prompt and "prompts" in data:
35
+ prompts = data.get("prompts", [""])
36
+ prompt = prompts[0] if prompts else ""
 
 
 
37
 
 
 
 
 
38
  # Generate a placeholder SVG
39
+ model_name = os.path.basename(os.getcwd())
40
+ svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" viewBox="0 0 512 512"><text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20">{model_name}: {prompt}</text></svg>'
41
 
42
  # Create a placeholder image
43
+ image = Image.new('RGB', (512, 512), color = (100, 100, 100))
44
 
45
  # Convert the image to base64
46
  buffered = io.BytesIO()
 
51
  return {
52
  "svg": svg,
53
  "image": img_str
54
+ }
55
+