Devashishraghav committed on
Commit
821a664
·
verified ·
1 Parent(s): 080728c

Upload processor/upscale.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. processor/upscale.py +135 -0
processor/upscale.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import io
5
+ import base64
6
+ from PIL import Image
7
+ from dotenv import load_dotenv
8
+ from google import genai
9
+ from google.genai import types
10
+
11
+ load_dotenv()
12
+
13
+ # Cached Gemini client; starts empty and is created lazily by get_client()
14
+ _client = None
15
+
16
def get_client():
    """
    Return the shared Gemini client, creating it on first use.

    The client is cached in the module-level ``_client`` variable, so the
    API key is only read from the environment the first time this is called.

    Raises:
        ValueError: if GEMINI_API_KEY is unset, empty, or still the
            "your_api_key_here" placeholder.
    """
    global _client

    # Fast path: a client has already been built.
    if _client is not None:
        return _client

    key = os.getenv("GEMINI_API_KEY")
    if key in (None, "", "your_api_key_here"):
        raise ValueError("GEMINI_API_KEY not set. Please add it to your .env file.")

    _client = genai.Client(api_key=key)
    return _client
24
+
25
+
26
def upscale_image(img: np.ndarray) -> np.ndarray:
    """
    Upscale an image with Gemini, falling back to local upscaling on failure.

    The colour channels are converted from OpenCV BGR to a PIL RGB image,
    passed to ``_gemini_upscale``, and the result is converted back to BGR.
    If the input has an alpha channel (BGRA), the alpha plane is resized to
    the size of the returned image, binarised at 127, and re-attached.

    Args:
        img: BGR (H, W, 3) or BGRA (H, W, 4) image array.

    Returns:
        The upscaled image with the same channel layout as the input. If any
        step of the Gemini path raises, the result of
        ``_local_fallback_upscale(img)`` is returned instead.
    """
    has_alpha = img.ndim == 3 and img.shape[2] == 4

    if has_alpha:
        bgr = img[:, :, :3]
        alpha = img[:, :, 3]
    else:
        bgr = img
        alpha = None

    try:
        # Convert BGR (OpenCV) to RGB (PIL)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb)

        # Call Gemini API to upscale
        upscaled_pil = _gemini_upscale(pil_image)

        # The decoded image keeps whatever mode the returned file had
        # (e.g. RGBA, P or L). Normalise to 3-channel RGB so the RGB->BGR
        # conversion below cannot fail on an otherwise valid result.
        if upscaled_pil.mode != "RGB":
            upscaled_pil = upscaled_pil.convert("RGB")

        # Convert back to OpenCV BGR
        upscaled_rgb = np.array(upscaled_pil)
        upscaled_bgr = cv2.cvtColor(upscaled_rgb, cv2.COLOR_RGB2BGR)

        if alpha is None:
            return upscaled_bgr

        # Resize alpha to match the upscaled image, then re-binarise because
        # Lanczos interpolation introduces intermediate values at the edges.
        h, w = upscaled_bgr.shape[:2]
        upscaled_alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LANCZOS4)
        _, upscaled_alpha = cv2.threshold(upscaled_alpha, 127, 255, cv2.THRESH_BINARY)

        return cv2.merge((
            upscaled_bgr[:, :, 0],
            upscaled_bgr[:, :, 1],
            upscaled_bgr[:, :, 2],
            upscaled_alpha,
        ))

    except Exception as e:
        # Deliberate best-effort boundary: any failure (missing API key,
        # network error, unusable response) degrades to local upscaling.
        print(f"Gemini upscale failed: {e}")
        print("Falling back to local upscaling...")
        return _local_fallback_upscale(img)
75
+
76
+
77
def _gemini_upscale(pil_image: Image.Image) -> Image.Image:
    """
    Use the Gemini API to upscale/enhance an image.

    Sends ``pil_image`` together with a fixed upscaling prompt to the
    "gemini-2.0-flash-exp" model and returns the first inline image found
    in the first candidate of the response.

    Args:
        pil_image: The image to enhance.

    Returns:
        The decoded image from the response, in whatever mode the returned
        file uses.

    Raises:
        ValueError: if the response has no candidates, no content, or no
            part carrying inline image data.
    """
    client = get_client()

    response = client.models.generate_content(
        model="gemini-2.0-flash-exp",
        contents=[
            "Upscale this credit card image to high resolution. "
            "Make the text sharp, crisp, and readable. "
            "Preserve all colors, logos, textures, and details exactly. "
            "Do not add any watermarks, borders, or extra elements. "
            "Do not change the content of the image in any way. "
            "Output only the enhanced image.",
            pil_image,
        ],
        config=types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
        ),
    )

    # candidates, content and parts can each be missing or empty; treat every
    # such case as "no image" so the caller gets the ValueError below rather
    # than an opaque TypeError/IndexError/AttributeError.
    candidates = response.candidates or []
    content = candidates[0].content if candidates else None
    parts = (content.parts if content is not None else None) or []

    # Extract the image from the response
    for part in parts:
        blob = part.inline_data
        if blob is not None and blob.data:
            image = Image.open(io.BytesIO(blob.data))
            # Image.open is lazy; force decoding now so corrupt data fails
            # here instead of at first pixel access in the caller.
            image.load()
            return image

    raise ValueError("Gemini did not return an image in the response")
106
+
107
+
108
def _local_fallback_upscale(img: np.ndarray) -> np.ndarray:
    """
    Fallback: local 2x upscale used when the Gemini path fails.

    Pipeline on the colour channels: Lanczos resize to twice the width and
    height, bilateral filter, then an unsharp mask (2.0 * image minus
    1.0 * Gaussian blur). For BGRA input the alpha plane is resized to the
    same size, thresholded at 127, and merged back as the fourth channel.
    """
    is_bgra = img.ndim == 3 and img.shape[2] == 4
    color = img[:, :, :3] if is_bgra else img
    mask = img[:, :, 3] if is_bgra else None

    height, width = color.shape[:2]
    enlarged = cv2.resize(color, (width * 2, height * 2), interpolation=cv2.INTER_LANCZOS4)
    smoothed = cv2.bilateralFilter(enlarged, d=5, sigmaColor=40, sigmaSpace=40)

    # Unsharp mask
    soft = cv2.GaussianBlur(smoothed, (0, 0), 2.0)
    sharpened = cv2.addWeighted(smoothed, 2.0, soft, -1.0, 0)

    if mask is None:
        return sharpened

    out_h, out_w = sharpened.shape[:2]
    big_mask = cv2.resize(mask, (out_w, out_h), interpolation=cv2.INTER_LANCZOS4)
    _, big_mask = cv2.threshold(big_mask, 127, 255, cv2.THRESH_BINARY)

    blue, green, red = (sharpened[:, :, idx] for idx in range(3))
    return cv2.merge((blue, green, red, big_mask))