colt12 committed on
Commit
fa41512
·
verified ·
1 Parent(s): c1c2cf7

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -64
handler.py CHANGED
@@ -1,68 +1,33 @@
1
- import io
2
- from PIL import Image
3
  import torch
4
- from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, AutoConfig
 
 
5
 
6
# Load the model and processors once at import time so every request reuses them.
model_name = "colt12/maxcushion"
try:
    print("Loading model configuration...")
    config = AutoConfig.from_pretrained(model_name)

    print("Loading model...")
    # NOTE(review): `config` is a configuration object, never an instance of the
    # VisionEncoderDecoderModel *model* class, so this isinstance check always
    # evaluates False and the else branch always runs — confirm whether
    # VisionEncoderDecoderConfig was the intended comparison type.
    if isinstance(config, VisionEncoderDecoderModel):
        model = VisionEncoderDecoderModel.from_pretrained(model_name, config=config)
    else:
        # If the config is not for VisionEncoderDecoderModel, we might need to construct it manually
        encoder_config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k")
        decoder_config = AutoConfig.from_pretrained("gpt2")
        model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            "google/vit-base-patch16-224-in21k",
            "gpt2",
            encoder_config=encoder_config,
            decoder_config=decoder_config
        )
        # NOTE(review): this loads weights from a *local* path equal to the hub
        # id — it only works if "colt12/maxcushion/pytorch_model.bin" exists on
        # disk relative to the working directory; verify deployment layout.
        model.load_state_dict(torch.load(f"{model_name}/pytorch_model.bin"))

    print("Model loaded successfully.")

    print("Loading image processor...")
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    print("Image processor loaded successfully.")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print("Tokenizer loaded successfully.")
except Exception as e:
    # Surface the load failure in logs, then re-raise so the service fails fast
    # instead of serving with a half-initialized model.
    print(f"Error loading model or processors: {str(e)}")
    raise
 
40
def predict(image_bytes):
    """Generate a text caption for a single image supplied as raw bytes.

    Decodes the bytes with PIL, preprocesses them with the module-level
    image processor, and runs greedy generation on the captioning model.
    """
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    encoded = image_processor(images=pil_image, return_tensors="pt")

    # Inference only — no gradients needed.
    with torch.no_grad():
        token_ids = model.generate(
            encoded.pixel_values,
            max_length=50,
            num_return_sequences=1,
        )

    return tokenizer.decode(token_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
53
 
54
def inference(inputs):
    """Entry point: pull image bytes out of *inputs* and return a caption dict.

    Accepts either a file-like object under "file" or raw bytes under
    "bytes"; raises ValueError when neither key is present.
    """
    if "file" in inputs:
        raw_bytes = inputs["file"].read()
    elif "bytes" in inputs:
        raw_bytes = inputs["bytes"]
    else:
        raise ValueError("No valid input found. Expected 'file' or 'bytes'.")

    return {"caption": predict(raw_bytes)}
 
 
 
1
  import torch
2
+ from diffusers import StableDiffusionXLPipeline
3
+ import base64
4
+ from io import BytesIO
5
 
6
class InferenceHandler:
    """Serves text-to-image requests using the colt12/maxcushion SDXL pipeline."""

    def __init__(self):
        # Prefer GPU; fall back to CPU when CUDA is unavailable.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model_name = "colt12/maxcushion"
        # fp16 halves memory on GPU, but half-precision inference is not
        # reliably supported on CPU — choose the dtype to match the device.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=dtype,
            use_safetensors=True,
            variant="fp16",
        )
        self.pipe = self.pipe.to(self.device)

    def __call__(self, inputs):
        """Generate one image from a request dict.

        Args:
            inputs: dict with required key "prompt" (non-empty str) and
                optional keys "negative_prompt" (default ""),
                "num_inference_steps" (default 30) and "guidance_scale"
                (default 7.5).

        Returns:
            dict: {"image_base64": <base64-encoded PNG of the image>}.

        Raises:
            ValueError: if "prompt" is missing or empty.
        """
        prompt = inputs.get("prompt", "")
        if not prompt:
            raise ValueError("A prompt must be provided")

        negative_prompt = inputs.get("negative_prompt", "")
        # Tunable per request; defaults preserve the previous hard-coded values.
        steps = inputs.get("num_inference_steps", 30)
        guidance = inputs.get("guidance_scale", 7.5)

        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
        ).images[0]

        # Serialize to base64-encoded PNG so the response is JSON-friendly.
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"image_base64": image_base64}
32
 
# Module-level singleton: the serving framework imports `handler` and calls it
# per request, so the pipeline is loaded exactly once at import time.
handler = InferenceHandler()