jree423 committed on
Commit
b52bb75
·
verified ·
1 Parent(s): bee3e4f

Fix: Correct EndpointHandler class name and add robust error handling for model loading

Browse files
Files changed (1) hide show
  1. handler.py +43 -29
handler.py CHANGED
@@ -14,27 +14,33 @@ import torchvision.transforms as transforms
14
  import random
15
  import math
16
 
17
- class SVGDreamerHandler:
18
- def __init__(self):
19
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  self.model_id = "runwayml/stable-diffusion-v1-5"
21
 
22
- # Initialize the diffusion pipeline
23
- self.pipe = StableDiffusionPipeline.from_pretrained(
24
- self.model_id,
25
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
26
- safety_checker=None,
27
- requires_safety_checker=False
28
- ).to(self.device)
29
-
30
- # Use DDIM scheduler for better control
31
- self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
32
-
33
- # CLIP model for guidance
34
- self.clip_model = self.pipe.text_encoder
35
- self.clip_tokenizer = self.pipe.tokenizer
36
-
37
- print("SVGDreamer handler initialized successfully!")
 
 
 
 
 
 
38
 
39
  def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image:
40
  """
@@ -491,18 +497,26 @@ class SVGDreamerHandler:
491
 
492
  def get_text_embeddings(self, prompt: str):
493
  """Get CLIP text embeddings for the prompt"""
494
- with torch.no_grad():
495
- text_inputs = self.clip_tokenizer(
496
- prompt,
497
- padding="max_length",
498
- max_length=self.clip_tokenizer.model_max_length,
499
- truncation=True,
500
- return_tensors="pt"
501
- ).to(self.device)
502
 
503
- text_embeddings = self.clip_model(text_inputs.input_ids)[0]
504
-
505
- return text_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
506
 
507
  def extract_semantic_features(self, prompt: str):
508
  """Extract semantic features from prompt"""
 
14
  import random
15
  import math
16
 
17
+ class EndpointHandler:
18
+ def __init__(self, path=""):
19
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  self.model_id = "runwayml/stable-diffusion-v1-5"
21
 
22
+ try:
23
+ # Initialize the diffusion pipeline
24
+ self.pipe = StableDiffusionPipeline.from_pretrained(
25
+ self.model_id,
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
+ safety_checker=None,
28
+ requires_safety_checker=False
29
+ ).to(self.device)
30
+
31
+ # Use DDIM scheduler for better control
32
+ self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
33
+
34
+ # CLIP model for guidance
35
+ self.clip_model = self.pipe.text_encoder
36
+ self.clip_tokenizer = self.pipe.tokenizer
37
+
38
+ print("SVGDreamer handler initialized successfully!")
39
+ except Exception as e:
40
+ print(f"Warning: Could not load diffusion model: {e}")
41
+ self.pipe = None
42
+ self.clip_model = None
43
+ self.clip_tokenizer = None
44
 
45
  def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image:
46
  """
 
497
 
498
  def get_text_embeddings(self, prompt: str):
499
  """Get CLIP text embeddings for the prompt"""
500
+ if self.clip_model is None or self.clip_tokenizer is None:
501
+ # Return dummy embeddings if model not loaded
502
+ return torch.zeros((1, 77, 768))
 
 
 
 
 
503
 
504
+ try:
505
+ with torch.no_grad():
506
+ text_inputs = self.clip_tokenizer(
507
+ prompt,
508
+ padding="max_length",
509
+ max_length=self.clip_tokenizer.model_max_length,
510
+ truncation=True,
511
+ return_tensors="pt"
512
+ ).to(self.device)
513
+
514
+ text_embeddings = self.clip_model(text_inputs.input_ids)[0]
515
+
516
+ return text_embeddings
517
+ except Exception as e:
518
+ print(f"Error getting text embeddings: {e}")
519
+ return torch.zeros((1, 77, 768))
520
 
521
  def extract_semantic_features(self, prompt: str):
522
  """Extract semantic features from prompt"""