# RobustMMFM2 / gradio / gradio_app.py
# KC123hello's picture
# Upload files
# 308f265 verified
import gradio as gr
import os
import sys
import tempfile
import numpy as np
from PIL import Image
# Make the parent directory importable so the caption-generation code can be
# found. The path is made absolute so the entry keeps pointing at the same
# directory even if the working directory changes after start-up.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
)

# Import the caption generation function directly.
try:
    # Works when this file's own folder is on sys.path (e.g. run as a script).
    from run_caption import generate_caption as generate_caption_backend
except ModuleNotFoundError as exc:
    # Only fall back when `run_caption` itself is missing. If `run_caption`
    # was found but failed on one of *its* imports, re-raise so the real
    # cause is reported instead of a misleading fallback failure.
    if exc.name != "run_caption":
        raise
    # NOTE(review): `gradio` is also the name of the installed UI library
    # imported at the top of this file, so this fallback resolves against
    # that package - confirm it can actually find the local run_caption.
    from gradio.run_caption import generate_caption as generate_caption_backend
def _remove_file_quietly(path):
    """Delete *path* on a best-effort basis.

    A falsy path or an already-missing file is ignored. Any other OS-level
    failure is reported on stdout but never raised, so cleanup cannot hide
    the caller's real result or error.
    """
    if not path:
        return
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
    except OSError as exc:
        print(f"Warning: could not remove temporary file {path}: {exc}", flush=True)


def _save_upload_as_temp_jpeg(image):
    """Write the uploaded image to a fresh temporary ``.jpg`` and return its path.

    *image* is either a numpy array or an object offering the PIL image
    interface (``mode``, ``convert``, ``save``).
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image

    # JPEG cannot store alpha or palette data (e.g. RGBA, LA, P); saving such
    # an image under a .jpg name raises, so convert those modes to RGB first.
    if pil_image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        pil_image = pil_image.convert("RGB")

    # Reserve a unique file name, and close the handle before PIL writes to
    # the path (writing to a still-open temp file is not portable).
    with tempfile.NamedTemporaryFile(mode='wb', suffix='.jpg', delete=False) as tmp_file:
        tmp_path = tmp_file.name
    try:
        pil_image.save(tmp_path)
    except Exception:
        # Do not leave an empty/partial temp file behind when saving fails.
        _remove_file_quietly(tmp_path)
        raise
    return tmp_path


def _load_image_array(path):
    """Return the image at *path* as a numpy array, or None if it is unavailable."""
    if not path or not os.path.exists(path):
        return None
    # The context manager closes the file handle; np.array() copies the pixel
    # data out first, so the array stays valid after the file is closed.
    with Image.open(path) as img:
        return np.array(img)


def generate_caption_wrapper(image, epsilon, sparsity, attack_algo, num_iters):
    """
    Wrapper for caption generation that interfaces with the Gradio UI.

    The uploaded image is written to a temporary JPEG, handed to the caption
    backend, and the image files the backend reports are loaded into arrays.
    The temporary input and the reported output files are deleted afterwards,
    whether or not processing succeeded.

    Args:
        image: The uploaded image from Gradio (numpy array or PIL-style image)
        epsilon: Max perturbation value
        sparsity: Sparsity parameter for SAIF
        attack_algo: Attack algorithm (APGD or SAIF)
        num_iters: Number of iterations

    Returns:
        tuple: (original_caption, adversarial_caption, original_image,
        adversarial_image, perturbation_image). On failure the first element
        is an "Error: ..." message, the second is "" and the images are None.
    """
    if image is None:
        return "Please upload an image first.", "", None, None, None

    try:
        # Save the uploaded image to a temporary file
        tmp_image_path = _save_upload_as_temp_jpeg(image)
        try:
            # Call the backend function directly
            result_dict = generate_caption_backend(
                image_path=tmp_image_path,
                epsilon=epsilon,
                sparsity=sparsity,
                attack_algo=attack_algo,
                num_iters=num_iters,
                model_name="open_flamingo",
                num_shots=0,
                targeted=False
            )
        finally:
            # Clean up the temporary input even when the backend raises
            _remove_file_quietly(tmp_image_path)

        # Extract results; `or ''` also covers a caption explicitly set to None
        original = (result_dict.get('original_caption') or '').strip()
        adversarial = (result_dict.get('adversarial_caption') or '').strip()

        output_paths = [
            result_dict.get('original_image_path'),
            result_dict.get('adversarial_image_path'),
            result_dict.get('perturbation_image_path'),
        ]
        try:
            orig_image, adv_image, pert_image = [
                _load_image_array(path) for path in output_paths
            ]
        finally:
            # Remove every reported output file, including the ones that were
            # not reached because an earlier image failed to load.
            for path in output_paths:
                _remove_file_quietly(path)

        return original, adversarial, orig_image, adv_image, pert_image
    except Exception as e:
        # UI boundary: log the full traceback, show only a short message.
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg, flush=True)
        return f"Error: {str(e)}", "", None, None, None
# Assemble the Gradio page: a row of inputs, a row holding the trigger
# button, and a row of three result columns.
with gr.Blocks(title="Image Captioning") as demo:
    gr.Markdown("# Evaluating Robustness of Multimodal Models Against Adversarial Perturbations")
    gr.Markdown("Upload an image to generate the adversarial image and caption using the APGD/SAIF algorithm.")

    # Inputs: the image plus the attack settings, stacked in a single column.
    with gr.Row(), gr.Column():
        image_input = gr.Image(type="numpy", label="Upload Image")
        attack_algo = gr.Dropdown(
            label="Adversarial Attack Algorithm",
            choices=["APGD", "SAIF"],
            value="APGD",
            interactive=True,
        )
        epsilon = gr.Slider(
            label="Epsilon (max perturbation, 0-255 scale)",
            minimum=1,
            maximum=255,
            step=1,
            value=8,
            interactive=True,
        )
        sparsity = gr.Slider(
            label="Sparsity (L1 norm of the perturbation, for SAIF only)",
            minimum=0,
            maximum=10000,
            step=100,
            value=0,
            interactive=True,
        )
        num_iters = gr.Slider(
            label="Number of Iterations",
            minimum=1,
            maximum=100,
            step=1,
            value=8,
            interactive=True,
        )

    # The button that starts a run.
    with gr.Row(), gr.Column():
        generate_btn = gr.Button("Generate Captions", variant="primary")

    # Results: original on the left, perturbation in the middle, adversarial
    # on the right.
    with gr.Row():
        with gr.Column():
            orig_image_output = gr.Image(label="Original Image")
            orig_caption_output = gr.Textbox(
                label="Generated Original Caption",
                placeholder="Caption will appear here...",
                lines=5,
            )
        with gr.Column():
            pert_image_output = gr.Image(label="Perturbation (10x magnified)")
        with gr.Column():
            adv_image_output = gr.Image(label="Adversarial Image")
            adv_caption_output = gr.Textbox(
                label="Generated Adversarial Caption",
                placeholder="Caption will appear here...",
                lines=5,
            )

    # Wire the button to the wrapper. The input order follows the wrapper's
    # parameters and the output order follows the tuple it returns.
    wrapper_inputs = [image_input, epsilon, sparsity, attack_algo, num_iters]
    wrapper_outputs = [
        orig_caption_output,
        adv_caption_output,
        orig_image_output,
        adv_image_output,
        pert_image_output,
    ]
    generate_btn.click(
        fn=generate_caption_wrapper,
        inputs=wrapper_inputs,
        outputs=wrapper_outputs,
    )
if __name__ == "__main__":
    # Launch settings collected in one mapping and passed to Gradio as
    # keyword arguments; the app is only started when run as a script.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "debug": True,
        "show_error": True,
    }
    demo.launch(**launch_options)