Update app.py
app.py
CHANGED
@@ -19,48 +19,42 @@ load_dotenv()
 OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
 
 
-def encode_image_to_base64(image, max_width=..., max_height=..., quality=...):
-    ...
+def encode_image_to_base64(image):
+    # If image is a tuple (as sometimes provided by Gradio), take the first element
+    if isinstance(image, tuple):
+        image = image[0]
+
+    # If image is a numpy array, convert to PIL Image
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    # Ensure image is in PIL Image format
     if not isinstance(image, Image.Image):
-        raise ValueError("Input must be a PIL Image")
-
-    width, height = image.size
-    if width > max_width or height > max_height:
-        aspect_ratio = width / height
-        if aspect_ratio > 1:
-            new_width = max_width
-            new_height = int(new_width / aspect_ratio)
-        else:
-            new_height = max_height
-            new_width = int(new_height * aspect_ratio)
-        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+        raise ValueError("Input must be a PIL Image, numpy array, or tuple containing an image")
 
     buffered = io.BytesIO()
-
-    image.save(buffered, format="JPEG", quality=quality)
-    buffered.seek(0)
+    image.save(buffered, format="PNG")
     return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
+
 def analyze_image(image):
     client = OpenAI(api_key=OPENAI_API_KEY)
+    base64_image = encode_image_to_base64(image)
 
-    # ...
-
-
-    # Build the list-of-dicts prompt
-    prompt_dict = [
+    # --- MINIMAL FIX START ---
+    # We build a Python list of dicts, then JSON-encode it:
+    prompt_list = [
         {
             "type": "text",
-            "text": """Your task is to determine if the image is surprising or not.
-            ...
+            "text": """Your task is to determine if the image is surprising or not surprising.
+if the image is surprising, determine which element, figure or object in the image is making the image surprising and write it only in one sentence with no more than 6 words, otherwise, write 'NA'.
+Also rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
+Provide the response as a JSON with the following structure:
 {
-    ...
-}
-"""
+    "label": "[surprising OR not surprising]",
+    "element": "[element]",
+    "rating": [1-5]
+}"""
         },
         {
             "type": "image_url",
@@ -70,27 +64,29 @@ def analyze_image(image):
         }
     ]
 
-    json_prompt = json.dumps(prompt_dict)
-
+    prompt_json = json.dumps(prompt_list)
+
+    messages = [
+        {
+            "role": "user",
+            "content": prompt_json  # content must be a single string
+        }
+    ]
+    # --- MINIMAL FIX END ---
 
-    # Send request
     response = client.chat.completions.create(
-        model="gpt-4o-mini",
-        messages=[
-            {
-                "role": "user",
-                "content": json_prompt
-            }
-        ],
+        model="gpt-4o-mini",  # or whichever model you have access to
+        messages=messages,
         max_tokens=100,
         temperature=0.1,
-        response_format={...}
+        response_format={
+            "type": "json_object"
+        }
     )
 
     return response.choices[0].message.content
 
 
-
 def show_mask(mask, ax, random_color=False):
     if random_color:
         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -114,7 +110,7 @@ def process_image_detection(image, target_label, surprise_rating):
     original_size = image.size
 
     # Calculate relative font size based on image dimensions
-    base_fontsize = min(original_size) / 40  # Adjust this divisor
+    base_fontsize = min(original_size) / 40  # Adjust this divisor to change overall font size
 
     owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
     owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
@@ -137,6 +133,7 @@ def process_image_detection(image, target_label, surprise_rating):
 
     ax = plt.Axes(fig, [0., 0., 1., 1.])
     fig.add_axes(ax)
+
     plt.imshow(image)
 
     scores = results["scores"]
@@ -165,7 +162,7 @@ def process_image_detection(image, target_label, surprise_rating):
     mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
     show_mask(mask, ax=ax)
 
-    # Draw rectangle
+    # Draw rectangle with increased line width
     rect = patches.Rectangle(
         (box[0], box[1]),
         box[2] - box[0],
@@ -176,7 +173,7 @@ def process_image_detection(image, target_label, surprise_rating):
     )
     ax.add_patch(rect)
 
-    # ...
+    # Add confidence score with improved visibility
     plt.text(
         box[0], box[1] - base_fontsize,
         f'{max_score:.2f}',
@@ -186,7 +183,7 @@ def process_image_detection(image, target_label, surprise_rating):
         bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
     )
 
-    # ...
+    # Add label and rating with improved visibility
     plt.text(
         box[2] + base_fontsize / 2, box[1],
         f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
@@ -199,20 +196,17 @@ def process_image_detection(image, target_label, surprise_rating):
 
     plt.axis('off')
 
-    # Save figure to buffer
     buf = io.BytesIO()
-    plt.savefig(
-        ...
-        metadata={'dpi': original_dpi}
-    )
+    plt.savefig(buf,
+                format='png',
+                dpi=dpi,
+                bbox_inches='tight',
+                pad_inches=0,
+                metadata={'dpi': original_dpi})
     buf.seek(0)
     plt.close()
 
-    # ...
+    # Process final image
     output_image = Image.open(buf)
     output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
 
@@ -233,17 +227,16 @@ def process_and_analyze(image):
     try:
         # Handle different input types
        if isinstance(image, tuple):
-            image = image[0]
+            image = image[0]  # Take the first element if it's a tuple
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format")
 
-        # Analyze image
+        # Analyze image
        gpt_response = analyze_image(image)
        response_data = json.loads(gpt_response)
 
-        # If surprising, try to detect the element
        if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
            result_buf = process_image_detection(image, response_data["element"], response_data["rating"])
            result_image = Image.open(result_buf)
@@ -254,7 +247,6 @@ def process_and_analyze(image):
            )
            return result_image, analysis_text
        else:
-            # If not surprising or element=NA
            return image, "Not Surprising"
 
    except Exception as e:
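The body of the image_url entry (new lines 61-63, between the first two hunks) is collapsed in this view. For reference only, a minimal sketch of the shape such an entry usually takes once the encoder emits PNG; the field layout is an assumption based on the standard OpenAI chat-completions image payload, not the hidden app.py lines:

# Hypothetical sketch of the collapsed image_url entry; the field names follow
# the common OpenAI image payload format, not the elided lines of app.py.
base64_image = encode_image_to_base64(image)  # PNG-encoded after this change
image_entry = {
    "type": "image_url",
    "image_url": {"url": f"data:image/png;base64,{base64_image}"},
}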
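To sanity-check the updated flow end to end, a minimal sketch, assuming the functions above are importable from app.py, OPENAI_API_KEY is set in the environment, and example.jpg is a hypothetical test image:

import json
from PIL import Image

from app import analyze_image, process_and_analyze  # hypothetical import path

img = Image.open("example.jpg")  # hypothetical test image
raw = analyze_image(img)         # JSON string, enforced by response_format={"type": "json_object"}
data = json.loads(raw)
print(data["label"], data["element"], data["rating"])

# Full pipeline: analyze, then detect and annotate the element if surprising
annotated, text = process_and_analyze(img)

Because the request sets response_format={"type": "json_object"}, json.loads on the raw reply should not raise for well-formed responses; process_and_analyze still wraps the whole flow in try/except as a guard.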