KC123hello committed on
Commit
c49d4d8
·
verified ·
1 Parent(s): 7365249

Upload 2 files

Browse files
Files changed (2) hide show
  1. gradio/gradio_app.py +186 -0
  2. gradio/run_caption.py +221 -0
gradio/gradio_app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import os
4
+ import tempfile
5
+ import json
6
+
7
+ def generate_caption(image, epsilon, sparsity, attack_algo, num_iters):
8
+ """
9
+ Generate caption for the uploaded image using the model in RobustMMFMEnv.
10
+
11
+ Args:
12
+ image: The uploaded image from Gradio
13
+
14
+ Returns:
15
+ tuple: (original_caption, adversarial_caption, original_image, adversarial_image, perturbation_image)
16
+ """
17
+ if image is None:
18
+ return "Please upload an image first.", "", None, None, None
19
+
20
+ try:
21
+ # Save the uploaded image to a temporary file
22
+ with tempfile.NamedTemporaryFile(mode='wb', suffix='.jpg', delete=False) as tmp_file:
23
+ tmp_image_path = tmp_file.name
24
+ # Save the image
25
+ from PIL import Image
26
+ import numpy as np
27
+
28
+ if isinstance(image, np.ndarray):
29
+ img = Image.fromarray(image)
30
+ img.save(tmp_image_path)
31
+ else:
32
+ image.save(tmp_image_path)
33
+
34
+ # Prepare the command to run in RobustMMFMEnv
35
+ # This is a placeholder - you'll need to create the actual script
36
+ conda_env = "RobustMMFMEnv"
37
+ script_path = os.path.join(os.path.dirname(__file__), "run_caption.py")
38
+
39
+ # Run the caption generation script in the RobustMMFMEnv conda environment
40
+ cmd = [
41
+ "conda", "run", "-n", conda_env,
42
+ "python", script_path,
43
+ "--image_path", tmp_image_path,
44
+ "--epsilon", str(epsilon),
45
+ "--num_iters", str(num_iters),
46
+ "--sparsity", str(sparsity),
47
+ "--attack_algo", attack_algo
48
+ ]
49
+
50
+ result = subprocess.run(
51
+ cmd,
52
+ capture_output=True,
53
+ text=True,
54
+ timeout=60 # 60 seconds timeout
55
+ )
56
+
57
+ # Clean up temporary file
58
+ os.unlink(tmp_image_path)
59
+
60
+ if result.returncode == 0:
61
+ # Parse the output
62
+ output = result.stdout.strip()
63
+ #return output if output else "No caption generated."
64
+
65
+ try:
66
+ # Parse the dictionary output
67
+ import ast
68
+ result_dict = ast.literal_eval(output)
69
+
70
+ original = result_dict.get('original_caption', '').strip()
71
+ adversarial = result_dict.get('adversarial_caption', '').strip()
72
+
73
+ orig_img_path = result_dict.get('original_image_path')
74
+ adv_img_path = result_dict.get('adversarial_image_path')
75
+ pert_img_path = result_dict.get('perturbation_image_path')
76
+
77
+ orig_image = None
78
+ adv_image = None
79
+ pert_image = None
80
+
81
+ if orig_img_path and os.path.exists(orig_img_path):
82
+ orig_image = np.array(Image.open(orig_img_path))
83
+ try:
84
+ os.unlink(orig_img_path)
85
+ except:
86
+ pass
87
+
88
+ if adv_img_path and os.path.exists(adv_img_path):
89
+ adv_image = np.array(Image.open(adv_img_path))
90
+ try:
91
+ os.unlink(adv_img_path)
92
+ except:
93
+ pass
94
+
95
+ if pert_img_path and os.path.exists(pert_img_path):
96
+ pert_image = np.array(Image.open(pert_img_path))
97
+ try:
98
+ os.unlink(pert_img_path)
99
+ except:
100
+ pass
101
+
102
+ return original, adversarial, orig_image, adv_image, pert_image # Return 5 values
103
+
104
+ except (ValueError, SyntaxError) as e:
105
+ print(f"Failed to parse output: {e}", flush=True)
106
+ # If parsing fails, try to return raw output
107
+ return f"Parse error: {str(e)}", "", None, None, None
108
+ else:
109
+ error_msg = result.stderr.strip()
110
+ return f"Error generating caption: {error_msg}", "", None, None, None
111
+
112
+ except subprocess.TimeoutExpired:
113
+ return "Error: Caption generation timed out (>60s)", "", None, None, None
114
+ except Exception as e:
115
+ return f"Error: {str(e)}", "", None, None, None
116
+
117
+ # Create the Gradio interface
118
+ with gr.Blocks(title="Image Captioning") as demo:
119
+ gr.Markdown("# Evaluating Robustness of Multimodal Models Against Adversarial Perturbations")
120
+ gr.Markdown("Upload an image to generate the adversarial image and caption using the APGD/SAIF algorithm.")
121
+
122
+ with gr.Row():
123
+ with gr.Column():
124
+ image_input = gr.Image(
125
+ label="Upload Image",
126
+ type="numpy"
127
+ )
128
+
129
+ attack_algo = gr.Dropdown(
130
+ choices=["APGD", "SAIF"],
131
+ value="APGD",
132
+ label="Adversarial Attack Algorithm",
133
+ interactive=True
134
+ )
135
+
136
+ epsilon = gr.Slider(
137
+ minimum=1, maximum=255, value=8, step=1, interactive=True,
138
+ label="Epsilon (max perturbation, 0-255 scale)"
139
+ )
140
+ sparsity = gr.Slider(
141
+ minimum=0, maximum=10000, value=0, step=100, interactive=True,
142
+ label="Sparsity (L1 norm of the perturbation, for SAIF only)"
143
+ )
144
+ num_iters = gr.Slider(
145
+ minimum=1, maximum=100, value=8, step=1, interactive=True,
146
+ label="Number of Iterations"
147
+ )
148
+
149
+ with gr.Row():
150
+ with gr.Column():
151
+ generate_btn = gr.Button("Generate Captions", variant="primary")
152
+
153
+ with gr.Row():
154
+ with gr.Column():
155
+ orig_image_output = gr.Image(label="Original Image")
156
+ orig_caption_output = gr.Textbox(
157
+ label="Generated Original Caption",
158
+ lines=5,
159
+ placeholder="Caption will appear here..."
160
+ )
161
+ with gr.Column():
162
+ pert_image_output = gr.Image(label="Perturbation (10x magnified)")
163
+ with gr.Column():
164
+ adv_image_output = gr.Image(label="Adversarial Image")
165
+ adv_caption_output = gr.Textbox(
166
+ label="Generated Adversarial Caption",
167
+ lines=5,
168
+ placeholder="Caption will appear here..."
169
+ )
170
+
171
+ # Set up the button click event
172
+ generate_btn.click(
173
+ fn=generate_caption,
174
+ inputs=[image_input, epsilon, sparsity, attack_algo, num_iters],
175
+ outputs=[orig_caption_output, adv_caption_output, orig_image_output, adv_image_output, pert_image_output]
176
+ )
177
+
178
+
179
+ if __name__ == "__main__":
180
+ demo.launch(
181
+ server_name="0.0.0.0",
182
+ server_port=7860,
183
+ share=True,
184
+ debug=True,
185
+ show_error=True
186
+ )
gradio/run_caption.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script to generate captions for images using the VLM model.
3
+ This script runs in the RobustMMFMEnv conda environment.
4
+ """
5
+
6
+ import argparse
7
+ import sys
8
+ import os
9
+ import warnings
10
+
11
+
12
+ warnings.filterwarnings('ignore')
13
+
14
+
15
+ # Add the parent directory to the path to import vlm_eval modules
16
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
17
+
18
def generate_caption(image_path, epsilon, sparsity, attack_algo, num_iters, model_name="open_flamingo", num_shots=0, targeted=False):
    """
    Generate original and adversarial captions for a single image.

    Args:
        image_path: Path to the image file.
        epsilon: Max perturbation magnitude on a 0-255 scale.
        sparsity: L1 budget for SAIF (ignored by APGD).
        attack_algo: "APGD" or "SAIF".
        num_iters: Number of attack iterations.
        model_name: Name of the model to use (currently unused).
        num_shots: Number of shots for few-shot learning (currently unused).
        targeted: If True, attack towards the fixed target string instead of
            away from the original caption.

    Returns:
        dict: Captions plus temp-file paths of the original, adversarial and
        (10x magnified) perturbation images. On failure, a dict whose
        "original_caption" carries the error message and whose paths are None.
    """
    try:
        # Heavy imports stay inside the function so any import failure is
        # reported through the error-dict path below.
        from PIL import Image
        import torch
        import numpy as np
        import tempfile
        from open_flamingo.eval.models.of_eval_model_adv import EvalModelAdv
        from open_flamingo.eval.coco_metric import postprocess_captioning_generation
        from vlm_eval.attacks.apgd import APGD
        from vlm_eval.attacks.saif import SAIF

        # Model arguments
        model_args = {
            "lm_path": "togethercomputer/RedPajama-INCITE-Base-3B-v1",
            "lm_tokenizer_path": "togethercomputer/RedPajama-INCITE-Base-3B-v1",
            "vision_encoder_path": "ViT-L-14",
            "vision_encoder_pretrained": "openai",
            "checkpoint_path": "/home/kc/.cache/huggingface/hub/models--openflamingo--OpenFlamingo-4B-vitl-rpj3b/snapshots/df8d3f7e75bcf891ce2fbf5253a12f524692d9c2/checkpoint.pt",
            "cross_attn_every_n_layers": "2",
            "precision": "float16",
        }

        eval_model = EvalModelAdv(model_args, adversarial=True)
        eval_model.set_device(0 if torch.cuda.is_available() else -1)

        image = Image.open(image_path).convert("RGB")
        image = eval_model._prepare_images([[image]])

        prompt = eval_model.get_caption_prompt()

        # Caption the clean image first.
        orig_caption = eval_model.get_outputs(
            batch_images=image,
            batch_text=[prompt],  # batch of size 1
            min_generation_length=0,
            max_generation_length=20,
            num_beams=3,
            length_penalty=-2.0,
        )

        # Build the text the attack optimizes against.
        # BUG FIX: `targeted` was previously reset to False unconditionally
        # here, which made the parameter dead; the caller's value is now honored.
        target_str = "a dog"  # target caption used when targeted=True
        adv_caption = target_str if targeted else orig_caption[0]
        prompt_adv = eval_model.get_caption_prompt(adv_caption)

        # Bind the adversarial prompt to the model before running the attack.
        eval_model.set_inputs(
            batch_text=[prompt_adv],
            past_key_values=None,
            to_device=True,
        )

        # Run the requested attack.
        if attack_algo == "APGD":
            attack = APGD(
                eval_model if not targeted else lambda x: -eval_model(x),
                norm="linf",
                eps=epsilon / 255.0,
                mask_out=None,
                initial_stepsize=1.0,
            )
            adv_image = attack.perturb(
                image.to(eval_model.device, dtype=eval_model.cast_dtype),
                iterations=num_iters,
                pert_init=None,
                verbose=False,
            )
        elif attack_algo == "SAIF":
            attack = SAIF(
                model=eval_model,
                targeted=targeted,
                img_range=(0, 1),
                steps=num_iters,
                mask_out=None,
                eps=epsilon / 255.0,
                k=sparsity,
                ver=False
            )
            adv_image, _ = attack(
                x=image.to(eval_model.device, dtype=eval_model.cast_dtype),
            )
        else:
            raise ValueError(f"Unsupported attack algorithm: {attack_algo}")

        adv_image = adv_image.detach().cpu()

        # Caption the adversarial image using the clean prompt.
        adv_caption_output = eval_model.get_outputs(
            batch_images=adv_image,
            batch_text=[prompt],
            min_generation_length=0,
            max_generation_length=20,
            num_beams=3,
            length_penalty=-2.0,
        )
        new_predictions = [
            postprocess_captioning_generation(out).replace('"', "") for out in adv_caption_output
        ]

        # NOTE(review): assumes the vision input is 3x224x224 (ViT-L-14
        # default) — confirm if the vision encoder ever changes.
        orig_img_np = image.view(1, 3, 224, 224).squeeze(0).cpu().permute(1, 2, 0).numpy()
        adv_img_np = adv_image.view(1, 3, 224, 224).squeeze(0).cpu().permute(1, 2, 0).numpy()

        # Perturbation (adversarial minus original), magnified 10x for display.
        perturbation_magnified = (adv_img_np - orig_img_np) * 10

        def _to_uint8(arr):
            # Min-max normalize to [0, 255] for display.
            # BUG FIX: guard against a constant array (e.g. an all-zero
            # perturbation), which previously divided by zero.
            lo, hi = arr.min(), arr.max()
            if hi == lo:
                return np.zeros(arr.shape, dtype=np.uint8)
            return ((arr - lo) / (hi - lo) * 255).astype(np.uint8)

        orig_img_np = _to_uint8(orig_img_np)
        adv_img_np = _to_uint8(adv_img_np)
        pert_img_np = _to_uint8(perturbation_magnified)

        # Save images to temporary files; the caller (gradio_app) deletes them.
        with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f:
            orig_img_path = f.name
        Image.fromarray(orig_img_np).save(orig_img_path)

        with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f:
            adv_img_path = f.name
        Image.fromarray(adv_img_np).save(adv_img_path)

        with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f:
            pert_img_path = f.name
        Image.fromarray(pert_img_np).save(pert_img_path)

        return {
            "original_caption": orig_caption[0],
            "adversarial_caption": new_predictions[0],
            "original_image_path": orig_img_path,  # Return file paths
            "adversarial_image_path": adv_img_path,
            "perturbation_image_path": pert_img_path
        }

    except Exception as e:
        import traceback
        error_msg = f"Error in caption generation: {str(e)}\n{traceback.format_exc()}"
        print(error_msg, file=sys.stderr, flush=True)
        # Return a dict with error information so the caller can still parse
        # a consistent structure from stdout.
        return {
            "original_caption": f"Error: {str(e)}",
            "adversarial_caption": "",
            "original_image_path": None,
            "adversarial_image_path": None,
            "perturbation_image_path": None
        }
196
+
197
def main():
    """CLI entry point: parse arguments, run the attack, print the result dict."""
    parser = argparse.ArgumentParser(description="Generate caption for an image")
    parser.add_argument("--image_path", type=str, required=True, help="Path to the image")
    parser.add_argument("--model", type=str, default="open_flamingo", help="Model to use")
    parser.add_argument("--shots", type=int, default=0, help="Number of shots")
    parser.add_argument("--epsilon", type=float, default=8.0, help="Epsilon for adversarial attack")
    parser.add_argument("--sparsity", type=int, default=0, help="Sparsity for SAIF attack")
    parser.add_argument("--attack_algo", type=str, default="APGD", help="Adversarial attack algorithm (APGD or SAIF)")
    parser.add_argument("--num_iters", type=int, default=100, help="Number of iterations for adversarial attack")
    args = parser.parse_args()

    caption = generate_caption(
        args.image_path,
        args.epsilon,
        args.sparsity,
        args.attack_algo,
        args.num_iters,
        args.model,
        args.shots,
    )

    # Guard clause: bail out with a non-zero status if nothing came back.
    if not caption:
        print("Failed to generate caption", file=sys.stderr)
        sys.exit(1)

    # The parent process (gradio_app) parses this printed dict from stdout.
    print(caption)
    sys.exit(0)


if __name__ == "__main__":
    main()