Devashishraghav committed on
Commit
3f4b262
·
verified ·
1 Parent(s): 40ba0e0

Upload processor/upscale.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. processor/upscale.py +126 -82
processor/upscale.py CHANGED
@@ -1,116 +1,160 @@
1
  import cv2
2
  import numpy as np
3
  import os
4
- import io
5
- import base64
6
- from PIL import Image
7
- from dotenv import load_dotenv
8
- from google import genai
9
- from google.genai import types
10
 
11
- load_dotenv()
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Initialize the Gemini client
14
- _client = None
15
 
16
- def get_client():
17
- global _client
18
- if _client is None:
19
- api_key = os.getenv("GEMINI_API_KEY")
20
- if not api_key or api_key == "your_api_key_here":
21
- raise ValueError("GEMINI_API_KEY not set. Please add it to your .env file.")
22
- _client = genai.Client(api_key=api_key)
23
- return _client
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  def upscale_image(img: np.ndarray) -> np.ndarray:
27
  """
28
- Upscales the image using Google Gemini's image generation API.
29
- Sends the cropped card image to Gemini with a prompt to upscale it,
30
- then returns the AI-enhanced result.
31
-
32
  Handles both BGR and BGRA (transparent) images.
33
- Falls back to local upscaling if Gemini API fails.
34
  """
35
  has_alpha = len(img.shape) == 3 and img.shape[2] == 4
36
-
37
  if has_alpha:
38
  bgr = img[:, :, :3]
39
  alpha = img[:, :, 3]
40
  else:
41
  bgr = img
42
  alpha = None
43
-
44
  try:
45
- # Convert BGR (OpenCV) to RGB (PIL)
46
- rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
47
- pil_image = Image.fromarray(rgb)
48
-
49
- # Call Gemini API to upscale
50
- upscaled_pil = _gemini_upscale(pil_image)
51
-
52
- # Convert back to OpenCV BGR
53
- upscaled_rgb = np.array(upscaled_pil)
54
- upscaled_bgr = cv2.cvtColor(upscaled_rgb, cv2.COLOR_RGB2BGR)
55
-
56
  if alpha is not None:
57
- # Resize alpha to match the upscaled image
58
- h, w = upscaled_bgr.shape[:2]
59
- upscaled_alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LANCZOS4)
60
  _, upscaled_alpha = cv2.threshold(upscaled_alpha, 127, 255, cv2.THRESH_BINARY)
61
-
62
  return cv2.merge((
63
  upscaled_bgr[:, :, 0],
64
  upscaled_bgr[:, :, 1],
65
  upscaled_bgr[:, :, 2],
66
- upscaled_alpha
67
  ))
68
- else:
69
- return upscaled_bgr
70
-
71
  except Exception as e:
72
- print(f"Gemini upscale failed: {e}")
73
- print("Falling back to local upscaling...")
74
  return _local_fallback_upscale(img)
75
 
76
 
77
- def _gemini_upscale(pil_image: Image.Image) -> Image.Image:
78
- """
79
- Uses the Gemini API to upscale/enhance an image.
80
- """
81
- client = get_client()
82
-
83
- response = client.models.generate_content(
84
- model="gemini-2.0-flash-exp",
85
- contents=[
86
- "Upscale this credit card image to high resolution. "
87
- "Make the text sharp, crisp, and readable. "
88
- "Preserve all colors, logos, textures, and details exactly. "
89
- "Do not add any watermarks, borders, or extra elements. "
90
- "Do not change the content of the image in any way. "
91
- "Output only the enhanced image.",
92
- pil_image,
93
- ],
94
- config=types.GenerateContentConfig(
95
- response_modalities=["IMAGE", "TEXT"],
96
- ),
97
- )
98
-
99
- # Extract the image from the response
100
- for part in response.candidates[0].content.parts:
101
- if part.inline_data is not None:
102
- img_bytes = part.inline_data.data
103
- return Image.open(io.BytesIO(img_bytes))
104
-
105
- raise ValueError("Gemini did not return an image in the response")
106
-
107
-
108
  def _local_fallback_upscale(img: np.ndarray) -> np.ndarray:
109
  """
110
- Fallback: local multi-pass Lanczos + sharpening if Gemini API is unavailable.
111
  """
112
  has_alpha = len(img.shape) == 3 and img.shape[2] == 4
113
-
114
  if has_alpha:
115
  bgr = img[:, :, :3]
116
  alpha = img[:, :, 3]
@@ -121,15 +165,15 @@ def _local_fallback_upscale(img: np.ndarray) -> np.ndarray:
121
  h, w = bgr.shape[:2]
122
  upscaled = cv2.resize(bgr, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
123
  upscaled = cv2.bilateralFilter(upscaled, d=5, sigmaColor=40, sigmaSpace=40)
124
-
125
  # Unsharp mask
126
  blurred = cv2.GaussianBlur(upscaled, (0, 0), 2.0)
127
  upscaled = cv2.addWeighted(upscaled, 2.0, blurred, -1.0, 0)
128
-
129
  if alpha is not None:
130
  uh, uw = upscaled.shape[:2]
131
  upscaled_alpha = cv2.resize(alpha, (uw, uh), interpolation=cv2.INTER_LANCZOS4)
132
  _, upscaled_alpha = cv2.threshold(upscaled_alpha, 127, 255, cv2.THRESH_BINARY)
133
- return cv2.merge((upscaled[:,:,0], upscaled[:,:,1], upscaled[:,:,2], upscaled_alpha))
134
-
135
  return upscaled
 
1
  import cv2
2
  import numpy as np
3
  import os
4
+ import urllib.request
 
 
 
 
 
5
 
6
# ─── Configuration ───────────────────────────────────────────────
# Weights directory lives at "<repo-root>/weights", one level above this
# processor package (derived from this file's location).
MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "weights")
MODEL_FILENAME = "realesrgan_x4plus.onnx"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILENAME)
# Source of the ONNX export of Real-ESRGAN x4plus (downloaded once on demand).
MODEL_URL = (
    "https://huggingface.co/Qualcomm/Real-ESRGAN-x4plus/resolve/main/"
    "Real-ESRGAN-x4plus.onnx"
)
SCALE_FACTOR = 4  # Real-ESRGAN x4plus produces a fixed 4x upscale per pass
TILE_SIZE = 256  # Process in tiles to limit memory usage
TILE_OVERLAP = 16  # Overlap between tiles for seamless stitching

# Lazy-loaded ONNX session (created on first use by _get_session)
_session = None
20
 
21
+
22
def _ensure_model():
    """Download the Real-ESRGAN ONNX model if it doesn't exist locally.

    The file is fetched to a temporary ``.part`` path and then atomically
    renamed into place, so an interrupted or failed download never leaves a
    truncated file at MODEL_PATH — a truncated file would otherwise be
    mistaken for a valid model on every subsequent run.
    """
    if os.path.exists(MODEL_PATH):
        return
    os.makedirs(MODEL_DIR, exist_ok=True)
    print(f"Downloading Real-ESRGAN x4plus model to {MODEL_PATH} ...")
    print("(This is a one-time download, ~67 MB)")
    tmp_path = MODEL_PATH + ".part"
    try:
        urllib.request.urlretrieve(MODEL_URL, tmp_path)
        os.replace(tmp_path, MODEL_PATH)  # atomic within the same filesystem
    finally:
        # Remove any partial file left behind by a failed download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download complete.")
31
+
32
+
33
def _get_session():
    """Return the module-wide ONNX Runtime session, creating it on first use."""
    global _session
    if _session is not None:
        return _session

    # Imported lazily so the module can be loaded without onnxruntime
    # installed (the fallback path never touches it).
    import onnxruntime as ort

    ort.set_default_logger_severity(3)  # keep runtime logs quiet
    _ensure_model()

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    _session = ort.InferenceSession(
        MODEL_PATH,
        sess_options=sess_options,
        providers=["CPUExecutionProvider"],
    )
    return _session
49
+
50
+
51
def _run_esrgan_tile(session, tile_bgr: np.ndarray) -> np.ndarray:
    """
    Run a single BGR tile through the Real-ESRGAN ONNX model.

    Input: uint8 BGR HWC tile. Output: uint8 BGR HWC tile, 4x larger.
    """
    # Preprocess: BGR -> RGB, scale to [0, 1], HWC -> NCHW batch of one.
    rgb = cv2.cvtColor(tile_bgr, cv2.COLOR_BGR2RGB)
    batch = rgb.astype(np.float32)[np.newaxis].transpose(0, 3, 1, 2) / 255.0

    feed = {session.get_inputs()[0].name: batch}
    chw = session.run(None, feed)[0][0]  # 3 x (4H) x (4W)

    # Postprocess: CHW -> HWC, rescale to [0, 255], clip, RGB -> BGR uint8.
    hwc = np.clip(chw.transpose(1, 2, 0) * 255.0, 0, 255).astype(np.uint8)
    return cv2.cvtColor(hwc, cv2.COLOR_RGB2BGR)
66
+
67
+
68
def _upscale_tiled(session, img_bgr: np.ndarray) -> np.ndarray:
    """
    Upscale a full BGR image using tiled inference with overlap blending.

    Tiling keeps peak memory bounded on large images; overlapping tiles are
    averaged where they meet to avoid visible seams.

    Fix over the naive loop: near the right/bottom edges, clamping the tile
    origin (``min(y, ph - TILE_SIZE)``) can make several loop steps land on
    the SAME origin, re-running inference on identical tiles — wasted
    compute, and the duplicated tiles get extra weight in the blend average.
    Deduplicating the origins runs each tile exactly once and weights every
    tile equally.
    """
    h, w = img_bgr.shape[:2]
    sf = SCALE_FACTOR

    # Reflect-pad so both dimensions are divisible by TILE_SIZE.
    pad_h = (TILE_SIZE - h % TILE_SIZE) % TILE_SIZE
    pad_w = (TILE_SIZE - w % TILE_SIZE) % TILE_SIZE
    padded = cv2.copyMakeBorder(img_bgr, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
    ph, pw = padded.shape[:2]

    # Accumulators: per-pixel sum of tile outputs and hit counts.
    out_h, out_w = ph * sf, pw * sf
    output = np.zeros((out_h, out_w, 3), dtype=np.float64)
    weight = np.zeros((out_h, out_w, 1), dtype=np.float64)

    # Unique clamped tile origins along each axis.
    step = TILE_SIZE - TILE_OVERLAP
    ys = sorted({min(y, ph - TILE_SIZE) for y in range(0, ph, step)})
    xs = sorted({min(x, pw - TILE_SIZE) for x in range(0, pw, step)})

    for ty in ys:
        for tx in xs:
            tile = padded[ty : ty + TILE_SIZE, tx : tx + TILE_SIZE]

            # Run inference on this tile.
            upscaled_tile = _run_esrgan_tile(session, tile)

            # Accumulate into the output canvas; overlaps are averaged below.
            oy, ox = ty * sf, tx * sf
            th, tw = upscaled_tile.shape[:2]
            output[oy : oy + th, ox : ox + tw] += upscaled_tile.astype(np.float64)
            weight[oy : oy + th, ox : ox + tw] += 1.0

    # Average overlapping regions (guard against division by zero).
    weight = np.maximum(weight, 1.0)
    output = (output / weight).clip(0, 255).astype(np.uint8)

    # Crop the padding away from the upscaled result.
    return output[: h * sf, : w * sf]
113
 
114
 
115
def upscale_image(img: np.ndarray) -> np.ndarray:
    """
    Upscale an image using Real-ESRGAN via ONNX Runtime.

    Handles both BGR and BGRA (transparent) images: the colour planes go
    through the network, while the alpha mask is resized separately and
    re-binarised so transparency keeps hard edges.
    Falls back to local Lanczos upscaling if ONNX inference fails.
    """
    if len(img.shape) == 3 and img.shape[2] == 4:
        bgr, alpha = img[:, :, :3], img[:, :, 3]
    else:
        bgr, alpha = img, None

    try:
        enhanced = _upscale_tiled(_get_session(), bgr)

        if alpha is None:
            return enhanced

        # Bring the alpha mask up to the new resolution, then re-threshold
        # it to a hard cutout rather than a soft interpolated gradient.
        new_h, new_w = enhanced.shape[:2]
        mask = cv2.resize(alpha, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        return cv2.merge((
            enhanced[:, :, 0],
            enhanced[:, :, 1],
            enhanced[:, :, 2],
            mask,
        ))

    except Exception as e:  # best effort: any failure drops to local resize
        print(f"Real-ESRGAN upscale failed: {e}")
        print("Falling back to local Lanczos upscaling...")
        return _local_fallback_upscale(img)
150
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def _local_fallback_upscale(img: np.ndarray) -> np.ndarray:
153
  """
154
+ Fallback: local multi-pass Lanczos + sharpening if ONNX is unavailable.
155
  """
156
  has_alpha = len(img.shape) == 3 and img.shape[2] == 4
157
+
158
  if has_alpha:
159
  bgr = img[:, :, :3]
160
  alpha = img[:, :, 3]
 
165
  h, w = bgr.shape[:2]
166
  upscaled = cv2.resize(bgr, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
167
  upscaled = cv2.bilateralFilter(upscaled, d=5, sigmaColor=40, sigmaSpace=40)
168
+
169
  # Unsharp mask
170
  blurred = cv2.GaussianBlur(upscaled, (0, 0), 2.0)
171
  upscaled = cv2.addWeighted(upscaled, 2.0, blurred, -1.0, 0)
172
+
173
  if alpha is not None:
174
  uh, uw = upscaled.shape[:2]
175
  upscaled_alpha = cv2.resize(alpha, (uw, uh), interpolation=cv2.INTER_LANCZOS4)
176
  _, upscaled_alpha = cv2.threshold(upscaled_alpha, 127, 255, cv2.THRESH_BINARY)
177
+ return cv2.merge((upscaled[:, :, 0], upscaled[:, :, 1], upscaled[:, :, 2], upscaled_alpha))
178
+
179
  return upscaled