Spaces:
No application file
No application file
| import gradio as gr | |
| import os | |
| import sys | |
| import tempfile | |
| import numpy as np | |
| from PIL import Image | |
# Put the directory above this file at the front of the import path so the
# caption-generation module can be located.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))

# Bind the backend entry point under a single local name, whichever layout
# the app is started from.
try:
    # Layout 1: run_caption is importable as a top-level module.
    from run_caption import generate_caption as generate_caption_backend
except ImportError:
    # Layout 2: run_caption is reachable only through the "gradio" package path.
    from gradio.run_caption import generate_caption as generate_caption_backend
def _remove_quietly(path):
    """Delete *path*, ignoring filesystem errors (cleanup is best-effort)."""
    try:
        os.unlink(path)
    except OSError:
        # The file may already be gone or still be locked; a leftover scratch
        # file must never turn an otherwise finished request into a failure.
        pass


def _save_upload_to_temp(image):
    """Write the uploaded image to a new temporary ``.jpg`` file and return its path."""
    # Only the unique name is needed here: the handle is closed right away so
    # the image writer can reopen the path on its own.
    with tempfile.NamedTemporaryFile(mode='wb', suffix='.jpg', delete=False) as tmp_file:
        tmp_image_path = tmp_file.name
    try:
        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        # JPEG cannot store an alpha channel or a palette, so such images are
        # flattened to RGB before saving instead of failing in save().
        if getattr(pil_image, "mode", None) in ("RGBA", "LA", "P"):
            pil_image = pil_image.convert("RGB")
        pil_image.save(tmp_image_path)
    except Exception:
        # Do not leave the empty/partial temp file behind when saving fails.
        _remove_quietly(tmp_image_path)
        raise
    return tmp_image_path


def _load_image_and_remove(path):
    """Return the image stored at *path* as a numpy array and delete the file.

    Returns None when *path* is empty or does not exist.
    """
    if not path or not os.path.exists(path):
        return None
    try:
        # Close the file handle before unlinking; an open handle can make the
        # delete fail on some platforms.
        with Image.open(path) as img:
            return np.array(img)
    finally:
        _remove_quietly(path)


def generate_caption_wrapper(image, epsilon, sparsity, attack_algo, num_iters):
    """
    Wrapper for caption generation that interfaces with Gradio UI.

    The uploaded image is written to a temporary file, handed to the backend,
    and the image files the backend reports are loaded into arrays and then
    deleted. Any exception is logged with its traceback and reported to the
    UI as an error string instead of being raised.

    Args:
        image: The uploaded image from Gradio (numpy array, or any object
            exposing a PIL-style ``save(path)`` method)
        epsilon: Max perturbation value
        sparsity: Sparsity parameter for SAIF
        attack_algo: Attack algorithm (APGD or SAIF)
        num_iters: Number of iterations

    Returns:
        tuple: (original_caption, adversarial_caption, original_image,
        adversarial_image, perturbation_image). On a missing upload or on
        failure the first element carries the message, the second is ""
        and the three images are None.
    """
    if image is None:
        return "Please upload an image first.", "", None, None, None

    try:
        # Save the uploaded image to a temporary file
        tmp_image_path = _save_upload_to_temp(image)
        try:
            # Call the backend function directly
            result_dict = generate_caption_backend(
                image_path=tmp_image_path,
                epsilon=epsilon,
                sparsity=sparsity,
                attack_algo=attack_algo,
                num_iters=num_iters,
                model_name="open_flamingo",
                num_shots=0,
                targeted=False
            )
        finally:
            # Remove the temporary input whether or not the backend succeeded.
            _remove_quietly(tmp_image_path)

        # Extract results; "or ''" guards against a caption stored as None.
        original = (result_dict.get('original_caption') or '').strip()
        adversarial = (result_dict.get('adversarial_caption') or '').strip()

        orig_image = _load_image_and_remove(result_dict.get('original_image_path'))
        adv_image = _load_image_and_remove(result_dict.get('adversarial_image_path'))
        pert_image = _load_image_and_remove(result_dict.get('perturbation_image_path'))

        return original, adversarial, orig_image, adv_image, pert_image
    except Exception as e:
        # UI boundary: log the full traceback, show only the short message.
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg, flush=True)
        return f"Error: {str(e)}", "", None, None, None
# Build the Gradio UI: a controls row, a button row and a results row.
# NOTE(review): the source lost its indentation; the nesting below is
# reconstructed from the order of the calls - confirm against the running app.
with gr.Blocks(title="Image Captioning") as demo:
    gr.Markdown("# Evaluating Robustness of Multimodal Models Against Adversarial Perturbations")
    gr.Markdown("Upload an image to generate the adversarial image and caption using the APGD/SAIF algorithm.")

    # --- Inputs: image upload plus attack settings ---
    with gr.Row(), gr.Column():
        image_input = gr.Image(type="numpy", label="Upload Image")
        attack_algo = gr.Dropdown(
            label="Adversarial Attack Algorithm",
            choices=["APGD", "SAIF"],
            value="APGD",
            interactive=True,
        )
        epsilon = gr.Slider(
            label="Epsilon (max perturbation, 0-255 scale)",
            minimum=1,
            maximum=255,
            value=8,
            step=1,
            interactive=True,
        )
        sparsity = gr.Slider(
            label="Sparsity (L1 norm of the perturbation, for SAIF only)",
            minimum=0,
            maximum=10000,
            value=0,
            step=100,
            interactive=True,
        )
        num_iters = gr.Slider(
            label="Number of Iterations",
            minimum=1,
            maximum=100,
            value=8,
            step=1,
            interactive=True,
        )

    # --- Trigger ---
    with gr.Row(), gr.Column():
        generate_btn = gr.Button("Generate Captions", variant="primary")

    # --- Outputs: original | perturbation | adversarial ---
    with gr.Row():
        with gr.Column():
            orig_image_output = gr.Image(label="Original Image")
            orig_caption_output = gr.Textbox(
                label="Generated Original Caption",
                placeholder="Caption will appear here...",
                lines=5,
            )
        with gr.Column():
            pert_image_output = gr.Image(label="Perturbation (10x magnified)")
        with gr.Column():
            adv_image_output = gr.Image(label="Adversarial Image")
            adv_caption_output = gr.Textbox(
                label="Generated Adversarial Caption",
                placeholder="Caption will appear here...",
                lines=5,
            )

    # Wire the button to the wrapper. The order of `outputs` must match the
    # order of the tuple the wrapper returns (captions first, then images).
    generate_btn.click(
        fn=generate_caption_wrapper,
        inputs=[image_input, epsilon, sparsity, attack_algo, num_iters],
        outputs=[
            orig_caption_output,
            adv_caption_output,
            orig_image_output,
            adv_image_output,
            pert_image_output,
        ],
    )
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # accept connections on every network interface
        "server_port": 7860,
        "share": True,
        "debug": True,
        "show_error": True,
    }
    demo.launch(**launch_options)