| import streamlit as st | |
| import os | |
| import sys | |
| import tempfile | |
| import zipfile | |
| import json | |
| import random | |
| import math | |
| import csv | |
| from pathlib import Path | |
| from datetime import datetime | |
| import time | |
# Make both this file's directory and its bundled ``scripts`` directory
# importable, so ``pipeline`` and ``generate_questions_mapping`` resolve
# regardless of the current working directory. ``scripts`` is inserted last,
# so it ends up first on sys.path.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
script_dir = os.path.join(_APP_DIR, 'scripts')
sys.path.insert(0, _APP_DIR)
sys.path.insert(0, script_dir)

try:
    from pipeline import (
        generate_counterfactuals,
        generate_base_scene,
        save_scene,
        render_scene,
        create_patched_render_script,
        IMAGE_COUNTERFACTUALS,
        NEGATIVE_COUNTERFACTUALS,
    )
    try:
        from generate_questions_mapping import (
            load_scene,
            generate_question_for_scene as _generate_question_for_scene_file,
            answer_question_for_scene,
            generate_mapping_with_questions,
        )
    except ImportError:
        # Question tooling is optional: provide minimal stand-ins so the app
        # can still generate scenes without it.
        def load_scene(scene_file):
            """Load and return the parsed JSON content of *scene_file*."""
            with open(scene_file, 'r') as f:
                return json.load(f)

        def answer_question_for_scene(question, scene):
            """Fallback answer: the number of objects in *scene* (*question* is ignored)."""
            objects = scene.get('objects', [])
            return len(objects)

        _generate_question_for_scene_file = None
        generate_mapping_with_questions = None
    PIPELINE_AVAILABLE = True
except ImportError as e:
    print(f"Warning: Error importing pipeline functions: {e}")
    PIPELINE_AVAILABLE = False
    # Define every optional name on this path too, so the module namespace is
    # the same whichever import branch ran.
    answer_question_for_scene = None
    load_scene = None
    _generate_question_for_scene_file = None
    generate_mapping_with_questions = None
# Page-level configuration (browser tab title/icon, wide layout, sidebar open),
# issued before any other UI element is drawn.
st.set_page_config(
    page_title="Counterfactual Image Generator",
    page_icon="🎨",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Stylesheet for the custom header class, the full-width primary buttons and
# the ``info-box`` helper class.
_PAGE_CSS = """
<style>
.main-header {
font-size: 2.5rem;
font-weight: bold;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.stButton>button {
width: 100%;
height: 3.5rem;
font-size: 1.2rem;
font-weight: bold;
background-color: #1f77b4;
color: white;
border-radius: 0.5rem;
}
.stButton>button:hover {
background-color: #1565c0;
}
.info-box {
padding: 1rem;
border-radius: 0.5rem;
background-color: #f0f2f6;
margin: 1rem 0;
}
</style>
"""

st.markdown(_PAGE_CSS, unsafe_allow_html=True)
def create_zip_file(output_dir, zip_path):
    """Write every file below *output_dir* into a deflate-compressed archive at *zip_path*.

    Archive member names are paths relative to *output_dir*, so the zip
    unpacks into the same layout without the absolute prefix.
    """
    with zipfile.ZipFile(zip_path, mode='w', compression=zipfile.ZIP_DEFLATED) as archive:
        for current_dir, _subdirs, filenames in os.walk(output_dir):
            for name in filenames:
                full_path = os.path.join(current_dir, name)
                archive.write(full_path, arcname=os.path.relpath(full_path, output_dir))
def generate_fallback_scene(num_objects, scene_idx):
    """Build a random scene description without Blender.

    Used when Blender is unavailable so scene JSON (and counterfactuals) can
    still be produced. The shape/colour/material/size vocabularies are read
    from ``data/properties.json`` next to this file; if that file is missing,
    unreadable or not valid JSON, a built-in default vocabulary is used.

    Objects are placed by rejection sampling in the square [-3, 3] x [-3, 3]:
    a candidate is rejected when its 2-D distance to an already-placed object
    is below the sum of both size radii plus ``min_dist``. After
    ``max_attempts`` rejections, one unchecked random placement is accepted.

    Args:
        num_objects: Number of objects to place (0 yields an empty object list).
        scene_idx: Zero-based scene index; the scene is numbered ``scene_idx + 1``.

    Returns:
        dict with keys ``split`` ('fallback'), ``image_index``,
        ``image_filename``, ``objects`` and ``directions``. Each object's
        ``pixel_coords`` is the placeholder ``[0, 0, 0]`` since nothing is rendered.
    """
    props_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data', 'properties.json')
    try:
        with open(props_path, 'r', encoding='utf-8') as f:
            properties = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file, or malformed JSON (JSONDecodeError is a ValueError).
        properties = {
            'shapes': {'cube': 'SmoothCube_v2', 'sphere': 'Sphere', 'cylinder': 'SmoothCylinder'},
            'colors': {'gray': [87, 87, 87], 'red': [173, 35, 35], 'blue': [42, 75, 215],
                       'green': [29, 105, 20], 'brown': [129, 74, 25], 'purple': [129, 38, 192],
                       'cyan': [41, 208, 208], 'yellow': [255, 238, 51]},
            'materials': {'rubber': 'Rubber', 'metal': 'MyMetal'},
            'sizes': {'large': 0.7, 'small': 0.35}
        }
    shapes = list(properties['shapes'])
    colors = list(properties['colors'])
    materials = list(properties['materials'])
    size_radius = properties['sizes']
    sizes = list(size_radius)

    scene_num = scene_idx + 1
    scene = {
        'split': 'fallback',
        'image_index': scene_num,
        'image_filename': f'scene_{scene_num:04d}_original.png',
        'objects': [],
        'directions': {
            'behind': (0.0, -1.0, 0.0),
            'front': (0.0, 1.0, 0.0),
            'left': (-1.0, 0.0, 0.0),
            'right': (1.0, 0.0, 0.0),
            'above': (0.0, 0.0, 1.0),
            'below': (0.0, 0.0, -1.0)
        }
    }

    min_dist = 0.25      # extra clearance required between object footprints
    max_attempts = 100   # rejection-sampling budget per object
    positions = []       # (x, y, z, radius) of objects placed so far

    def _random_candidate():
        # Draw order (x, y, z, size) is kept stable so seeded runs are reproducible.
        x = random.uniform(-3, 3)
        y = random.uniform(-3, 3)
        z = random.uniform(0.35, 0.7)
        size = random.choice(sizes)
        return x, y, z, size, size_radius[size]

    def _collides(x, y, r):
        # Only the ground-plane (x, y) distance is checked.
        return any(
            math.sqrt((x - px) ** 2 + (y - py) ** 2) < (r + pr + min_dist)
            for px, py, _pz, pr in positions
        )

    for _ in range(num_objects):
        for _attempt in range(max_attempts):
            x, y, z, size, r = _random_candidate()
            if not _collides(x, y, r):
                break
        else:
            # No collision-free spot found: accept an unchecked random placement.
            x, y, z, size, r = _random_candidate()
        positions.append((x, y, z, r))

        shape = random.choice(shapes)
        color = random.choice(colors)
        material = random.choice(materials)
        scene['objects'].append({
            'shape': shape,
            'size': size,
            'material': material,
            '3d_coords': [x, y, z],
            'rotation': random.uniform(0, 360),
            'pixel_coords': [0, 0, 0],
            'color': color
        })
    return scene
def generate_question_for_scene_dict(scene):
    """Generate a question and its parameters for an in-memory scene dict.

    When the file-based generator from ``generate_questions_mapping`` is
    available (``_generate_question_for_scene_file`` is not None), the scene is
    written to a temporary JSON file, that file is passed to the generator, and
    the file is removed afterwards, even if serialisation or generation fails.

    Otherwise a simple counting question is built locally.

    Returns:
        ``(question, params)``. The local fallback returns either
        ``("How many <color> objects are there?", {'color': <color>})`` or,
        when no object has a colour, ``("How many objects are in the scene?", {})``.
    """
    if _generate_question_for_scene_file is None:
        objects = scene.get('objects', [])
        # Sorted so the choice is reproducible under a seeded RNG (set order is not).
        colors = sorted({obj.get('color') for obj in objects if obj.get('color')}, key=str)
        if not colors:
            return "How many objects are in the scene?", {}
        # Pick once so the question text and params refer to the same colour.
        color = random.choice(colors)
        return f"How many {color} objects are there?", {'color': color}

    tmp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
    tmp_path = tmp_file.name
    try:
        # Close before handing the path on, so the generator can reopen it on any OS.
        with tmp_file:
            json.dump(scene, tmp_file)
        question, params = _generate_question_for_scene_file(tmp_path)
        return question, params
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass  # best-effort cleanup of the temporary file
def _remove_render_artifacts():
    """Best-effort removal of ``temp_output/`` and ``render_images_patched.py`` from the CWD.

    NOTE(review): both names are hard-coded here; they are presumably scratch
    outputs of the pipeline's Blender rendering step. Confirm against pipeline.py.

    Each removal is attempted up to three times with short sleeps. A
    persistent failure is printed, never raised.
    """
    import shutil
    import time

    cwd = os.getcwd()
    temp_output_dir = os.path.join(cwd, 'temp_output')
    if os.path.exists(temp_output_dir):
        for attempt in range(3):
            try:
                shutil.rmtree(temp_output_dir)
                break
            except Exception as e:
                if attempt < 2:
                    time.sleep(0.3)
                else:
                    print(f"Warning: Could not remove temp_output after 3 attempts: {e}")
    render_patched_path = os.path.join(cwd, 'render_images_patched.py')
    if os.path.exists(render_patched_path):
        for attempt in range(3):
            try:
                time.sleep(0.2)
                if os.path.exists(render_patched_path):
                    os.remove(render_patched_path)
                break
            except Exception as e:
                if attempt < 2:
                    time.sleep(0.3)
                else:
                    print(f"Warning: Could not remove render_images_patched.py after 3 attempts: {e}")


def _resolve_blender(blender_path):
    """Choose a Blender executable and report whether it looks usable.

    When *blender_path* is None, ``pipeline.find_blender()`` is tried, falling
    back to the bare ``'blender'`` command. An explicit path counts as
    available when it exists on disk. The bare command is probed with
    ``blender --version`` under a 5 s timeout.

    Returns:
        ``(blender_path, available)``. An empty/None result is normalised to
        ``'blender'`` so callers never pass None on as the executable.
    """
    import subprocess

    if blender_path is None:
        try:
            from pipeline import find_blender
            blender_path = find_blender()
        except Exception:
            blender_path = 'blender'
    if not blender_path:
        blender_path = 'blender'
    if blender_path != 'blender':
        return blender_path, os.path.exists(blender_path)
    try:
        result = subprocess.run([blender_path, '--version'], capture_output=True, timeout=5)
    except (OSError, subprocess.SubprocessError):
        return blender_path, False
    return blender_path, result.returncode == 0


def _generate_blender_scene(num_objects, blender_path, scene_idx, error_messages):
    """Try up to three times to build a base scene via ``pipeline.generate_base_scene``.

    Python-level stdout/stderr written during each attempt is captured. It is
    shown in the UI only for the final attempt. Output a child process writes
    directly to its own file descriptors is not captured by ``redirect_stdout``.

    Failures are appended to *error_messages*.

    Returns:
        ``(scene_or_None, blender_available)``. ``blender_available`` becomes
        False when Blender was not found or the final attempt raised.
    """
    import contextlib
    import io
    import traceback

    base_scene = None
    for retry in range(3):
        last_try = retry == 2
        try:
            output_buffer = io.StringIO()
            with contextlib.redirect_stdout(output_buffer), contextlib.redirect_stderr(output_buffer):
                base_scene = generate_base_scene(num_objects, blender_path, scene_idx)
            blender_output = output_buffer.getvalue()
            if blender_output and last_try:
                st.text(f"Blender output for scene {scene_idx + 1} (last 1000 chars):")
                st.code(blender_output[-1000:])
            if base_scene and len(base_scene.get('objects', [])) > 0:
                return base_scene, True
            if last_try:
                if base_scene is None:
                    scene_error = "generate_base_scene returned None - Blender may have failed (check output above)"
                else:
                    scene_error = "Scene has 0 objects - Blender may have hit max_retries (check output above)"
                error_messages.append(f"Scene {scene_idx + 1}: {scene_error}")
        except FileNotFoundError as e:
            error_messages.append(f"Scene {scene_idx + 1}: Blender not found: {e}")
            return base_scene, False
        except Exception as e:
            print(f"Error generating base scene (retry {retry + 1}/3): {e}")
            print(f" Traceback: {traceback.format_exc()}")
            if last_try:
                error_messages.append(
                    f"Scene {scene_idx + 1}: Error generating base scene: {str(e)} (Blender path: {blender_path})"
                )
                return base_scene, False
    return base_scene, True


def _write_mapping_csv(output_dir, csv_filename, generate_questions):
    """Write the image-mapping CSV via ``generate_mapping_with_questions`` if it is available.

    Errors are printed with a traceback and reported as "not created" rather
    than raised.

    Returns:
        ``(csv_path, created)``.
    """
    csv_path = os.path.join(output_dir, csv_filename)
    if generate_mapping_with_questions is None:
        return csv_path, False
    try:
        generate_mapping_with_questions(
            run_dir=output_dir,
            csv_filename=csv_filename,
            generate_questions=generate_questions,
            with_links=False,
            strict_question_validation=True
        )
    except Exception:
        import traceback
        traceback.print_exc()
        return csv_path, False
    return csv_path, os.path.exists(csv_path)


def generate_counterfactual_scenes(num_scenes, num_objects, min_objects, max_objects, num_counterfactuals,
                                   cf_types, same_cf_type, min_change_score, max_cf_attempts, min_noise_level,
                                   output_dir, blender_path=None, use_gpu=0, samples=512,
                                   width=320, height=240, skip_render=False, generate_questions=False,
                                   semantic_only=False, negative_only=False):
    """Generate scene sets (original + counterfactuals), optionally render them, and write a mapping CSV.

    For each scene, a base scene is built with Blender when it is available,
    otherwise with ``generate_fallback_scene``. Counterfactual variants are then
    derived with ``pipeline.generate_counterfactuals``. Every variant is saved
    to ``<output_dir>/scenes/scene_NNNN_<variant>.json``. It is also rendered to
    ``<output_dir>/images/scene_NNNN_<variant>.png`` unless *skip_render* is set
    or Blender is unavailable.

    Stale render artifacts in the CWD are removed before the run and again
    afterwards, including when the run fails.

    Args:
        num_scenes: Number of scene sets to generate.
        num_objects: Fixed object count per scene, or None to draw uniformly
            from ``[min_objects, max_objects]``.
        num_counterfactuals, cf_types, same_cf_type, min_change_score,
        max_cf_attempts, min_noise_level, semantic_only, negative_only:
            Forwarded to ``pipeline.generate_counterfactuals``.
        output_dir: Destination directory (``scenes/`` and ``images/`` are created).
        blender_path: Blender executable. None means auto-detect.
        use_gpu, samples, width, height: Forwarded to ``pipeline.render_scene``.
        skip_render: Only write scene JSON.
        generate_questions: Include question/answer columns in the CSV.

    Returns:
        dict with ``success``. On success it also has ``num_scenes``,
        ``output_dir``, ``statistics`` and ``error_messages`` (None when empty).
        On failure it has ``error`` and, when the run got started,
        ``num_scenes``, ``output_dir`` and ``error_messages``.
    """
    if not PIPELINE_AVAILABLE:
        return {
            'success': False,
            'error': 'Pipeline functions not available. Please ensure pipeline.py is accessible.'
        }
    scenes_dir = os.path.join(output_dir, 'scenes')
    images_dir = os.path.join(output_dir, 'images')
    os.makedirs(scenes_dir, exist_ok=True)
    os.makedirs(images_dir, exist_ok=True)

    _remove_render_artifacts()
    blender_path, blender_available = _resolve_blender(blender_path)

    successful_scenes = 0
    successful_renders = 0
    error_messages = []
    try:
        for scene_idx in range(num_scenes):
            if num_objects is not None:
                scene_num_objects = num_objects
            else:
                scene_num_objects = random.randint(min_objects, max_objects)

            if blender_available:
                base_scene, blender_available = _generate_blender_scene(
                    scene_num_objects, blender_path, scene_idx, error_messages
                )
            else:
                print(f"Scene {scene_idx + 1} (Blender not available)...")
                base_scene = generate_fallback_scene(scene_num_objects, scene_idx)

            if not base_scene or len(base_scene.get('objects', [])) == 0:
                error_detail = f"Scene {scene_idx + 1}: Failed to generate"
                if blender_available:
                    error_detail += f" (Blender was available at {blender_path} but returned empty scene)"
                else:
                    error_detail += " (Blender not available, fallback scene also failed)"
                print(f"Failed to generate scene {scene_idx + 1}")
                print(f" Blender available: {blender_available}")
                print(f" Blender path: {blender_path}")
                print(f" Base scene: {base_scene is not None}")
                if base_scene:
                    print(f" Objects in scene: {len(base_scene.get('objects', []))}")
                error_messages.append(error_detail)
                continue
            successful_scenes += 1

            counterfactuals = generate_counterfactuals(
                base_scene,
                num_counterfactuals=num_counterfactuals,
                cf_types=cf_types,
                same_cf_type=same_cf_type,
                min_change_score=min_change_score,
                max_cf_attempts=max_cf_attempts,
                # Previously hard-coded to 'light', which silently ignored the caller's choice.
                min_noise_level=min_noise_level,
                semantic_only=semantic_only,
                negative_only=negative_only
            )

            scene_prefix = f"scene_{scene_idx + 1:04d}"
            base_scene['cf_metadata'] = {
                'variant': 'original',
                'is_counterfactual': False,
                'cf_index': None,
                'cf_category': 'original',
                'cf_type': None,
                'cf_description': None,
                'source_scene': scene_prefix,
            }
            original_scene_path = os.path.join(scenes_dir, f"{scene_prefix}_original.json")
            save_scene(base_scene, original_scene_path)
            # (scene JSON, target PNG) pairs, rendered in this order.
            render_jobs = [(original_scene_path, os.path.join(images_dir, f"{scene_prefix}_original.png"))]

            for cf_index, cf in enumerate(counterfactuals, start=1):
                cf_name = f"cf{cf_index}"
                cf_scene = cf['scene']
                cf_scene['cf_metadata'] = {
                    'variant': cf_name,
                    'is_counterfactual': True,
                    'cf_index': cf_index,
                    'cf_category': cf.get('cf_category', 'unknown'),
                    'cf_type': cf.get('type', None),
                    'cf_description': cf.get('description', None),
                    'change_score': cf.get('change_score', None),
                    'change_attempts': cf.get('change_attempts', None),
                    'source_scene': scene_prefix,
                }
                cf_scene_path = os.path.join(scenes_dir, f"{scene_prefix}_{cf_name}.json")
                save_scene(cf_scene, cf_scene_path)
                render_jobs.append((cf_scene_path, os.path.join(images_dir, f"{scene_prefix}_{cf_name}.png")))

            if skip_render:
                continue
            if not (blender_path and blender_available):
                print("Blender not available - skipping image rendering. Scene JSON files will still be generated.")
                continue
            rendered = 0
            for scene_path, image_path in render_jobs:
                if render_scene(
                    blender_path,
                    scene_path,
                    image_path,
                    use_gpu=use_gpu,
                    samples=samples,
                    width=width,
                    height=height
                ):
                    rendered += 1
            # A scene set counts as rendered only if every variant rendered.
            if rendered == len(render_jobs):
                successful_renders += 1

        csv_filename = 'image_mapping_with_questions.csv' if generate_questions else 'image_mapping.csv'
        csv_path, csv_created = _write_mapping_csv(output_dir, csv_filename, generate_questions)

        scene_files = list(Path(scenes_dir).glob("*.json")) if os.path.exists(scenes_dir) else []
        image_files = list(Path(images_dir).glob("*.png")) if os.path.exists(images_dir) else []
        statistics = {
            'scenes_generated': successful_scenes,
            'scenes_rendered': successful_renders,
            'total_scene_files': len(scene_files),
            'total_image_files': len(image_files),
            'num_counterfactuals': num_counterfactuals,
            'cf_types_used': cf_types if cf_types else 'default',
            'csv_created': csv_created,
            'csv_path': csv_path if csv_created else None
        }
    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {
            'success': False,
            'error': error_msg,
            'num_scenes': successful_scenes,
            'output_dir': output_dir,
            'error_messages': error_messages
        }
    finally:
        # Runs on success and on failure, so no scratch files are left behind.
        _remove_render_artifacts()

    if successful_scenes == 0 and error_messages:
        error_summary = "Scenes failed. Common reasons:\n"
        error_summary += "- Blender is not installed or not in PATH\n"
        error_summary += "- Blender executable not found\n"
        error_summary += f"\nFirst error: {error_messages[0]}"
        return {
            'success': False,
            'error': error_summary,
            'num_scenes': successful_scenes,
            'output_dir': output_dir,
            'error_messages': error_messages
        }
    return {
        'success': True,
        'num_scenes': successful_scenes,
        'output_dir': output_dir,
        'statistics': statistics,
        'error_messages': error_messages if error_messages else None
    }
# (checkbox label, cf_type key) for the "Image CFs" expander, in display order.
_IMAGE_CF_OPTIONS = (
    ("Change Color", 'change_color'),
    ("Change Shape", 'change_shape'),
    ("Change Size", 'change_size'),
    ("Change Material", 'change_material'),
    ("Change Position", 'change_position'),
    ("Add Object", 'add_object'),
    ("Remove Object", 'remove_object'),
    ("Replace Object", 'replace_object'),
    ("Swap Attribute", 'swap_attribute'),
    ("Relational Flip", 'relational_flip'),
)

# (checkbox label, cf_type key) for the "Negative CFs" expander, in display order.
_NEGATIVE_CF_OPTIONS = (
    ("Change Background", 'change_background'),
    ("Change Lighting", 'change_lighting'),
    ("Add Noise", 'add_noise'),
    ("Occlusion Change", 'occlusion_change'),
    ("Apply Fisheye", 'apply_fisheye'),
    ("Apply Blur", 'apply_blur'),
    ("Apply Vignette", 'apply_vignette'),
    ("Apply Chromatic Aberration", 'apply_chromatic_aberration'),
)

# Guidance shown whenever generation produced no scenes.
_BLENDER_HELP_TEXT = (
    "**To use this application:**\n"
    "1. Run it locally with Blender installed\n"
    "2. Use the command-line `pipeline.py` script\n"
    "3. Install Blender and ensure it's in your system PATH"
)


def _clear_stale_render_files():
    """Best-effort removal of ``temp_output/`` and ``render_images_patched.py`` from the CWD.

    Each path is retried up to three times with short sleeps. Failures are
    printed rather than raised so the UI keeps working.
    """
    import shutil
    import time

    cwd = os.getcwd()
    temp_output_dir = os.path.join(cwd, 'temp_output')
    if os.path.exists(temp_output_dir):
        for attempt in range(3):
            try:
                shutil.rmtree(temp_output_dir)
                break
            except Exception as e:
                if attempt < 2:
                    time.sleep(0.3)
                else:
                    print(f"Warning: Could not remove temp_output after 3 attempts: {e}")
    render_patched_path = os.path.join(cwd, 'render_images_patched.py')
    if os.path.exists(render_patched_path):
        for attempt in range(3):
            try:
                time.sleep(0.2)
                if os.path.exists(render_patched_path):
                    os.remove(render_patched_path)
                break
            except Exception as e:
                if attempt < 2:
                    time.sleep(0.3)
                else:
                    print(f"Warning: Could not remove render_images_patched.py after 3 attempts: {e}")


def _sidebar_config():
    """Draw the sidebar controls and return the selected settings.

    Returns:
        dict of settings, or None when the min/max object range is invalid.
        In that case an error is shown in the sidebar and the caller should stop.
    """
    with st.sidebar:
        st.header("Configuration")
        st.subheader("Scene Settings")
        num_scenes = st.number_input(
            "Number of Scenes",
            min_value=1,
            max_value=10000,
            value=5,
            help="Number of scene sets to generate"
        )
        use_fixed_objects = st.checkbox("Use Fixed Number of Objects", value=True)
        num_objects = min_objects = max_objects = None
        if use_fixed_objects:
            num_objects = st.number_input(
                "Number of Objects per Scene",
                min_value=1,
                max_value=15,
                value=5,
                help="Fixed number of objects per scene"
            )
        else:
            min_objects = st.number_input(
                "Min Objects per Scene",
                min_value=1,
                max_value=15,
                value=3,
                help="Minimum objects per scene"
            )
            max_objects = st.number_input(
                "Max Objects per Scene",
                min_value=1,
                max_value=15,
                value=7,
                help="Maximum objects per scene"
            )
            if min_objects > max_objects:
                st.error("Min objects must be <= Max objects")
                return None

        st.subheader("Counterfactual Settings")
        num_counterfactuals = st.number_input(
            "Number of Counterfactuals",
            min_value=1,
            max_value=10,
            value=2,
            help="Number of counterfactual variants per scene"
        )
        st.markdown("**Counterfactual Types**")
        st.caption("Leave all unchecked to use default behavior (1 Image CF + 1 Negative CF)")
        semantic_only = st.checkbox(
            "Semantic only",
            value=False,
            help="Generate only Semantic/Image counterfactuals (Change Color, Add Object, etc.); no Negative CFs"
        )
        negative_only = st.checkbox(
            "Negative only",
            value=False,
            help="Generate only Negative counterfactuals (Change Lighting, Add Noise, Occlusion Change, etc.); no Semantic CFs"
        )
        same_cf_type = st.checkbox(
            "Same CF type for all",
            value=False,
            help="Use the same counterfactual type for every variant (first selected type, or one random if none selected)"
        )

        # Checkboxes are drawn in table order, so cf_types keeps that order too.
        cf_types = []
        with st.expander("Image CFs (change answers)", expanded=True):
            for label, cf_key in _IMAGE_CF_OPTIONS:
                if st.checkbox(label, value=False):
                    cf_types.append(cf_key)
        with st.expander("Negative CFs (don't change answers)", expanded=False):
            for label, cf_key in _NEGATIVE_CF_OPTIONS:
                if st.checkbox(label, value=False):
                    cf_types.append(cf_key)

        with st.expander("Advanced Settings", expanded=False):
            min_change_score = st.slider(
                "Minimum Change Score",
                min_value=0.5,
                max_value=5.0,
                value=1.0,
                step=0.1,
                help="Minimum heuristic change score for counterfactuals"
            )
            max_cf_attempts = st.number_input(
                "Max CF Attempts",
                min_value=1,
                max_value=50,
                value=10,
                help="Maximum retries per counterfactual"
            )
            min_noise_level = st.selectbox(
                "Min Noise Level (for add_noise CF)",
                options=['light', 'medium', 'heavy'],
                index=0,
                help="Minimum noise level when using add_noise counterfactual"
            )
            st.markdown("---")
            st.markdown("**Rendering Settings**")
            use_gpu = st.checkbox("Use GPU Rendering", value=False)
            samples = st.number_input(
                "Render Samples",
                min_value=64,
                max_value=2048,
                value=512,
                step=64,
                help="Cycles sampling rate (higher = better quality, slower)"
            )
            image_width = st.number_input(
                "Image Width",
                min_value=160,
                max_value=1920,
                value=320,
                step=80
            )
            image_height = st.number_input(
                "Image Height",
                min_value=120,
                max_value=1080,
                value=240,
                step=60
            )
            st.markdown("**CSV Options**")
            generate_questions = st.checkbox(
                "Generate Questions in CSV",
                value=False,
                help="Include question and answer columns in the CSV file"
            )

    return {
        'num_scenes': num_scenes,
        'use_fixed_objects': use_fixed_objects,
        'num_objects': num_objects,
        'min_objects': min_objects,
        'max_objects': max_objects,
        'num_counterfactuals': num_counterfactuals,
        'semantic_only': semantic_only,
        'negative_only': negative_only,
        'same_cf_type': same_cf_type,
        'cf_types': cf_types or None,  # None selects the pipeline's default mix
        'min_change_score': min_change_score,
        'max_cf_attempts': max_cf_attempts,
        'min_noise_level': min_noise_level,
        'use_gpu': 1 if use_gpu else 0,
        'samples': samples,
        'width': image_width,
        'height': image_height,
        'generate_questions': generate_questions,
    }


def _run_generation(cfg):
    """Handle a click on "Generate Counterfactual": validate, run the pipeline and report.

    Returns:
        True when the caller should stop drawing the page (validation failure
        or missing pipeline), otherwise False.
    """
    st.session_state.generation_complete = False
    st.session_state.generating = True
    num_scenes = cfg['num_scenes']
    use_fixed_objects = cfg['use_fixed_objects']
    num_objects = cfg['num_objects']
    min_objects = cfg['min_objects']
    max_objects = cfg['max_objects']

    if num_scenes < 1:
        st.error("Please specify at least 1 scene to generate.")
        return True
    if use_fixed_objects and num_objects < 1:
        st.error("Please specify at least 1 object per scene.")
        return True
    if not use_fixed_objects and (min_objects < 1 or max_objects < 1 or min_objects > max_objects):
        st.error("Invalid min/max objects configuration.")
        return True

    base_dir = '/tmp' if os.path.exists('/tmp') else tempfile.gettempdir()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(base_dir, f"counterfactual_output_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)
    st.session_state.output_dir = output_dir

    _clear_stale_render_files()
    # NOTE(review): generate_counterfactual_scenes() also deletes
    # render_images_patched.py from the CWD before it starts, so the patched
    # script created here may be removed again. Confirm in pipeline.py whether
    # this call is still needed.
    try:
        from pipeline import create_patched_render_script
        create_patched_render_script()
    except Exception as e:
        st.warning(f"Could not create patched render script: {e}")

    progress_bar = st.progress(0)
    status_text = st.empty()
    try:
        if not PIPELINE_AVAILABLE:
            st.error("Pipeline functions are not available. Please check your installation.")
            return True
        status_text.text("Initializing generator...")
        progress_bar.progress(10)
        if use_fixed_objects:
            status_text.text(f"Generating {num_scenes} scenes with {num_objects} objects each...")
        else:
            status_text.text(f"Generating {num_scenes} scenes with {min_objects}-{max_objects} objects each...")
        progress_bar.progress(30)
        result = generate_counterfactual_scenes(
            num_scenes=num_scenes,
            num_objects=num_objects,
            min_objects=min_objects,
            max_objects=max_objects,
            num_counterfactuals=cfg['num_counterfactuals'],
            cf_types=cfg['cf_types'],
            same_cf_type=cfg['same_cf_type'],
            min_change_score=cfg['min_change_score'],
            max_cf_attempts=cfg['max_cf_attempts'],
            min_noise_level=cfg['min_noise_level'],
            output_dir=output_dir,
            use_gpu=cfg['use_gpu'],
            samples=cfg['samples'],
            width=cfg['width'],
            height=cfg['height'],
            skip_render=False,
            generate_questions=cfg['generate_questions'],
            semantic_only=cfg['semantic_only'],
            negative_only=cfg['negative_only']
        )
        progress_bar.progress(80)
        status_text.text("Preparing output...")

        if result and result.get('success', False):
            num_scenes_generated = result.get('num_scenes', 0)
            if num_scenes_generated == 0:
                st.warning("No scenes were created. Blender is required and is not available in this environment.")
                st.info(_BLENDER_HELP_TEXT)
                st.session_state.generation_complete = False
            else:
                st.session_state.generation_complete = True
                progress_bar.progress(100)
                status_text.text("Done.")
                st.success(f"Successfully generated {num_scenes_generated} scene sets!")
                st.info(f"Output directory: {output_dir}")
                if 'statistics' in result:
                    stats = result['statistics'] or {}
                    csv_path = stats.get('csv_path')
                    if stats.get('csv_created') and csv_path:
                        st.success(f"CSV file created: `{os.path.basename(csv_path)}`")
                    st.json(stats)
        else:
            error_msg = str(result.get('error', 'Unknown error occurred')) if result else 'Failed'
            st.error(f"Generation failed: {error_msg}")
            # Guard on `result` first: it may be falsy, and the old code then crashed on `.get`.
            if not result or 'blender' in error_msg.lower() or result.get('num_scenes', 0) == 0:
                st.warning("**Important:** This application requires Blender to generate scenes. Blender is not available on Hugging Face Spaces.")
                st.info(_BLENDER_HELP_TEXT)
            st.session_state.generation_complete = False
        st.session_state.generating = False
    except Exception as e:
        st.error(f"Error during generation: {str(e)}")
        st.exception(e)
        st.session_state.generation_complete = False
        st.session_state.generating = False
        progress_bar.progress(0)
        status_text.text("Failed")
    return False


def _cf_caption(scene_file):
    """Return the title-cased ``cf_metadata.cf_type`` of *scene_file*, or "Counterfactual"."""
    try:
        with open(scene_file, 'r') as fh:
            scene_data = json.load(fh)
        cf_type = scene_data.get('cf_metadata', {}).get('cf_type', '')
        if cf_type:
            return cf_type.replace('_', ' ').title()
    except Exception:
        pass  # unreadable/odd scene file: fall back to the generic caption
    return "Counterfactual"


def _show_image_previews(image_files, scenes_dir, max_scenes=3):
    """Show the original plus up to two counterfactual images for the first *max_scenes* scene sets.

    Images are grouped by the ``scene_<num>_<variant>.png`` naming scheme.
    Sets without an ``original`` image are skipped.
    """
    scene_sets = {}
    for img_file in image_files:
        filename = img_file.name
        if not filename.startswith('scene_'):
            continue
        parts = filename.replace('.png', '').split('_')
        if len(parts) < 3:
            continue
        scene_num, variant = parts[1], parts[2]
        scene_sets.setdefault(scene_num, {})[variant] = {
            'image_path': str(img_file),
            'filename': filename
        }

    sorted_scenes = sorted(scene_sets)[:max_scenes]
    for position, scene_num in enumerate(sorted_scenes):
        variants = scene_sets[scene_num]
        if 'original' not in variants:
            continue
        st.markdown(f"### Scene {scene_num}")
        cols = st.columns(3)
        with cols[0]:
            st.image(variants['original']['image_path'], use_container_width=True, caption="Original")
        shown = 0
        for cf_key in ('cf1', 'cf2'):
            if cf_key not in variants:
                continue
            cf_data = variants[cf_key]
            cf_scene_file = os.path.join(scenes_dir, cf_data['filename'].replace('.png', '.json'))
            if os.path.exists(cf_scene_file):
                caption = _cf_caption(cf_scene_file)
            else:
                caption = f"Counterfactual {shown + 1}"
            with cols[shown + 1]:
                st.image(cf_data['image_path'], use_container_width=True, caption=caption)
            shown += 1
        if position < len(sorted_scenes) - 1:
            st.markdown("---")


def _show_output_panel():
    """Draw the right-hand "Output" column: file counts, previews and a ZIP download."""
    st.header("Output")
    output_dir = st.session_state.output_dir
    if not (st.session_state.generation_complete and output_dir):
        st.info("Configure parameters and click 'Generate Counterfactual' to start.")
        return
    if not os.path.exists(output_dir):
        st.warning("Output directory not found.")
        return

    images_dir = os.path.join(output_dir, 'images')
    scenes_dir = os.path.join(output_dir, 'scenes')
    scene_files = list(Path(scenes_dir).glob("*.json")) if os.path.exists(scenes_dir) else []
    image_files = list(Path(images_dir).glob("*.png")) if os.path.exists(images_dir) else []
    csv_files = list(Path(output_dir).rglob("*.csv"))
    st.success("Complete!")
    st.metric("Scene Files", len(scene_files))
    st.metric("CSV Files", len(csv_files))
    st.metric("Image Files", len(image_files))

    if image_files:
        st.markdown("---")
        st.subheader("Generated Images")
        _show_image_previews(image_files, scenes_dir)

    st.markdown("---")
    st.subheader("Download Output")
    # The output dir name already carries the generation timestamp. Reusing it
    # means Streamlit reruns overwrite one archive instead of leaving a new zip
    # in the temp dir on every interaction.
    zip_filename = f"{os.path.basename(os.path.normpath(output_dir))}.zip"
    zip_path = os.path.join(tempfile.gettempdir(), zip_filename)
    try:
        create_zip_file(output_dir, zip_path)
        file_size_mb = os.path.getsize(zip_path) / (1024 * 1024)
        with open(zip_path, 'rb') as zip_file:
            zip_bytes = zip_file.read()
        st.download_button(
            label=f"Download as ZIP ({file_size_mb:.2f} MB)",
            data=zip_bytes,
            file_name=zip_filename,
            mime="application/zip",
            use_container_width=True
        )
        with st.expander("Output Structure"):
            st.text(f"Output directory: {output_dir}")
            if scene_files:
                st.text(f"\nScene files: {len(scene_files)}")
                st.text("Sample files:")
                for path in scene_files[:5]:
                    st.text(f" - {path.name}")
            if csv_files:
                st.text(f"\nCSV files: {len(csv_files)}")
                for path in csv_files:
                    st.text(f" - {path.name}")
            if image_files:
                st.text(f"\nImage files: {len(image_files)}")
                st.text("Sample files:")
                for path in image_files[:5]:
                    st.text(f" - {path.name}")
    except Exception as e:
        st.error(f"Error creating zip file: {str(e)}")


def main():
    """Render the app: header, sidebar configuration, generation column and output column."""
    st.markdown('<p class="main-header">Counterfactual Image Generator</p>', unsafe_allow_html=True)
    if 'output_dir' not in st.session_state:
        st.session_state.output_dir = None
    if 'generation_complete' not in st.session_state:
        st.session_state.generation_complete = False

    cfg = _sidebar_config()
    if cfg is None:
        return

    col1, col2 = st.columns([2, 1])
    with col1:
        st.header("Generate Counterfactual Images")
        if st.button("Generate Counterfactual", use_container_width=True, key="generate_button"):
            if _run_generation(cfg):
                return
    with col2:
        _show_output_panel()

    st.markdown("---")
    st.markdown(
        "<div style='text-align: center; color: #666; padding: 1rem;'>"
        "Counterfactual Image Tool | Built with Streamlit"
        "</div>",
        unsafe_allow_html=True
    )


if __name__ == "__main__":
    main()