jree423 commited on
Commit
48cd50a
·
verified ·
1 Parent(s): 8e9cb18

Delete diffsketcher_handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffsketcher_handler.py +0 -92
diffsketcher_handler.py DELETED
@@ -1,92 +0,0 @@
1
- #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
-
4
- import os
5
- import sys
6
- import torch
7
- import numpy as np
8
- from PIL import Image
9
- import io
10
- import base64
11
- from handler_template import BaseHandler
12
-
13
- # Add DiffSketcher to path
14
- sys.path.append("/app/model")
15
-
16
- class Handler(BaseHandler):
17
- def initialize(self):
18
- """Load the DiffSketcher model"""
19
- try:
20
- from models.clip_text_encoder import CLIPTextEncoder
21
- from models.sketch_generator import SketchGenerator
22
-
23
- # Load text encoder
24
- self.text_encoder = CLIPTextEncoder()
25
- self.text_encoder.to(self.device)
26
- self.text_encoder.eval()
27
-
28
- # Load sketch generator
29
- self.model = SketchGenerator()
30
- weights_path = os.path.join("/app/model/weights", "diffsketcher_model.pth")
31
- if os.path.exists(weights_path):
32
- state_dict = torch.load(weights_path, map_location=self.device)
33
- self.model.load_state_dict(state_dict)
34
- else:
35
- raise FileNotFoundError(f"Model weights not found at {weights_path}")
36
-
37
- self.model.to(self.device)
38
- self.model.eval()
39
-
40
- self.initialized = True
41
- print("DiffSketcher model initialized successfully")
42
- except Exception as e:
43
- print(f"Error initializing DiffSketcher model: {str(e)}")
44
- raise
45
-
46
- def preprocess(self, data):
47
- """Process the input data"""
48
- try:
49
- # Extract prompt from the request
50
- prompt = data.get("prompt", "")
51
- if not prompt:
52
- raise ValueError("No prompt provided in the request")
53
-
54
- # Encode text with CLIP
55
- with torch.no_grad():
56
- text_embedding = self.text_encoder.encode_text(prompt)
57
-
58
- return {
59
- "text_embedding": text_embedding,
60
- "prompt": prompt
61
- }
62
- except Exception as e:
63
- print(f"Error in preprocessing: {str(e)}")
64
- raise
65
-
66
- def inference(self, inputs):
67
- """Generate SVG from text embedding"""
68
- try:
69
- text_embedding = inputs["text_embedding"]
70
-
71
- # Run inference
72
- with torch.no_grad():
73
- svg_data = self.model.generate(text_embedding)
74
-
75
- return svg_data
76
- except Exception as e:
77
- print(f"Error during inference: {str(e)}")
78
- raise
79
-
80
- def postprocess(self, inference_output):
81
- """Format the model output"""
82
- try:
83
- svg_content = inference_output["svg_content"]
84
-
85
- # Return both the SVG content and base64 encoded version
86
- return {
87
- "svg_content": svg_content,
88
- "svg_base64": self.svg_to_base64(svg_content)
89
- }
90
- except Exception as e:
91
- print(f"Error in postprocessing: {str(e)}")
92
- return {"error": str(e)}