primerz commited on
Commit
b851544
·
verified ·
1 Parent(s): 7eecce9

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +49 -0
utils.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration
5
+ import torch
6
+ from config import Config
7
+
8
+ def resize_image_to_1mp(image):
9
+ """Resizes image to approx 1MP (e.g., 1024x1024) preserving aspect ratio."""
10
+ w, h = image.size
11
+ target_pixels = 1024 * 1024
12
+ aspect_ratio = w / h
13
+
14
+ # Calculate new dimensions
15
+ new_h = int((target_pixels / aspect_ratio) ** 0.5)
16
+ new_w = int(new_h * aspect_ratio)
17
+
18
+ # Ensure divisibility by 8 (vae requirement), usually 32 for safety
19
+ new_w = (new_w // 32) * 32
20
+ new_h = (new_h // 32) * 32
21
+
22
+ return image.resize((new_w, new_h), Image.LANCZOS)
23
+
24
+ # Simple caching for captioner
25
+ captioner_processor = None
26
+ captioner_model = None
27
+
28
+ def get_caption(image):
29
+ global captioner_processor, captioner_model
30
+
31
+ if captioner_model is None:
32
+ print("Loading Captioner...")
33
+ captioner_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
34
+ captioner_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(Config.DEVICE)
35
+
36
+ inputs = captioner_processor(image, return_tensors="pt").to(Config.DEVICE)
37
+ out = captioner_model.generate(**inputs)
38
+ caption = captioner_processor.decode(out[0], skip_special_tokens=True)
39
+ return caption
40
+
41
+ def prepare_control_images(image, zoe_detector, lineart_detector):
42
+ """Generates the conditioning maps from the input image."""
43
+ # 1. Zoe Depth Map
44
+ depth_map = zoe_detector(image, detect_resolution=1024, image_resolution=1024)
45
+
46
+ # 2. LineArt Map
47
+ lineart_map = lineart_detector(image, detect_resolution=1024, image_resolution=1024)
48
+
49
+ return depth_map, lineart_map