import os

# Route every Hugging Face cache to a writable location: on hosted Spaces
# only /tmp is guaranteed writable for this process.
os.environ['HF_HOME'] = "/tmp/hf_cache"
CACHE_DIR = "/tmp/hf_cache"

import streamlit as st
from huggingface_hub import hf_hub_download
from datasets import load_dataset
from PIL import Image
import numpy as np
import random
import csv
import re
import datetime
import cv2
import torch
import torch.fft
import smtplib
import io
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.base import MIMEBase
from email import encoders
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
from datasets import Dataset
from huggingface_hub import HfApi
import pandas as pd
from dotenv import load_dotenv
import uuid

load_dotenv()  # reads .env and injects the variables into os.environ


def push_results_to_hub(user_id, responses):
    """Push one user's responses to the results dataset on the Hugging Face Hub.

    Each user_id is stored as its own split of the dataset repo.

    Args:
        user_id: anonymous identifier; used as the split name.
        responses: list of rows matching st.session_state.csv_headers.

    Returns:
        True when the push succeeded, False when a precondition failed
        (no token, nothing to push).
    """
    hf_token = os.environ.get("HF_TOKEN")
    repo_id = "rain-maker/image-study-results"

    # Fail fast with a visible message instead of an opaque HTTP error
    # from push_to_hub (these guards were previously commented out).
    if not hf_token:
        st.error("❌ No Hugging Face token found. Add HF_TOKEN to Space secrets.")
        return False
    if not responses:
        st.warning("⚠️ No responses to push.")
        return False

    # List-of-rows -> DataFrame with the canonical column order.
    df = pd.DataFrame(responses, columns=st.session_state.csv_headers)
    ds = Dataset.from_pandas(df, preserve_index=False)

    # One split per participant keeps results separable per user.
    ds.push_to_hub(repo_id, token=hf_token, split=user_id, private=True)
    return True


# --- ICS (Information Conservation Score) metric --------------------------

def kl_divergence(p, q, eps=1e-10):
    """KL(p || q) summed over the last two dimensions; eps guards log(0)."""
    p = p + eps
    q = q + eps
    return torch.sum(p * torch.log(p / q), dim=(-1, -2))


def power_spectrum(x):
    """Normalized 2-D power spectrum of x, taken over its LAST two dims."""
    fft = torch.fft.fft2(x)
    power = torch.abs(fft) ** 2
    return power / torch.sum(power, dim=(-1, -2), keepdim=True)


def information_conservation_score(x, x_hat, lambda1=0.5):
    """Blend of MS-SSIM and a frequency-domain KL-divergence term.

    Args:
        x, x_hat: H x W x C numpy images (uint8 or float).
        lambda1: weight of the MS-SSIM term; (1 - lambda1) weights the
            frequency-domain term.

    Returns:
        Python float; higher means more information conserved.
    """
    # Convert to float32 tensors, normalizing uint8 images to [0, 1].
    x_t = torch.from_numpy(x).float() / 255.0 if x.dtype == np.uint8 else torch.from_numpy(x).float()
    x_hat_t = torch.from_numpy(x_hat).float() / 255.0 if x_hat.dtype == np.uint8 else torch.from_numpy(x_hat).float()

    # H x W x C -> [1, C, H, W] as torchmetrics expects.
    x_t = x_t.permute(2, 0, 1).unsqueeze(0)
    x_hat_t = x_hat_t.permute(2, 0, 1).unsqueeze(0)

    ms_ssim_metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
    ms_ssim_score = ms_ssim_metric(x_t, x_hat_t)

    # BUG FIX: fft2/power_spectrum act on the last two dims. The raw numpy
    # inputs are H x W x C, so the original transformed over (W, C) instead
    # of the spatial plane. Reuse the channels-first tensors so the FFT
    # runs over (H, W) per channel. (The [0, 1] rescaling cancels out in
    # the normalized spectrum, so only the axis fix changes the result.)
    p_spec = power_spectrum(x_t)
    q_spec = power_spectrum(x_hat_t)
    kl = kl_divergence(p_spec, q_spec)
    # Normalize KL by log(N) so the term is roughly scale-free in image size.
    kl_normalized = kl / torch.log(torch.tensor(p_spec.numel(), dtype=torch.float32))
    freq_score = 1.0 - kl_normalized

    return (lambda1 * ms_ssim_score.mean() + (1 - lambda1) * freq_score.mean()).item()
@st.cache_resource
def load_my_dataset():
    """Load (and cache across reruns) the test split of the RAW-RAIN sample dataset."""
    return load_dataset(
        "rain-maker/RAW-RAIN-sample",
        split="test",
        repo_type="dataset",
    )


def get_image(scene, frame, kind):
    """Fetch one rendered frame from the dataset repo as a PIL image.

    Args:
        scene: scene directory name, e.g. "test_scene_3".
        frame: integer frame index.
        kind: "gt", "bayer" or "rgb" -- selects the <kind>_test subtree.
    """
    # The filename must match the repo structure exactly.
    local_path = hf_hub_download(
        repo_id="rain-maker/RAW-RAIN-sample",
        filename=f"{kind}_test/{scene}/rgb_output/output_{frame}.png",
        repo_type="dataset",
    )
    return Image.open(local_path)


def validate_inputs(full_name):
    """Validate user registration input; returns (is_valid, error_message)."""
    if not full_name.strip():
        return False, "Please enter your full name."
    return True, ""


def random_crop(gt_array, a_array, b_array, crop_size=512):
    """Cut the same random crop_size x crop_size window out of all three images.

    The three arrays are assumed to share spatial dimensions. When an image
    is smaller than crop_size the crop anchors at 0 and returns whatever is
    available instead of raising.
    """
    h, w = gt_array.shape[:2]
    # max(0, ...) keeps random.randint from raising ValueError on images
    # smaller than crop_size (the original crashed on its empty range).
    y = random.randint(0, max(0, h - crop_size))
    x = random.randint(0, max(0, w - crop_size))
    ys, xs = slice(y, y + crop_size), slice(x, x + crop_size)
    return gt_array[ys, xs], a_array[ys, xs], b_array[ys, xs]


def compute_metrics(gt_crop, a_crop, b_crop, ics_lambda=0.5):
    """Compute SSIM / PSNR / ICS for both candidate crops against ground truth.

    Args:
        gt_crop: ground-truth crop (H x W x C uint8).
        a_crop, b_crop: candidate crops (Bayer and RGB pipelines respectively).
        ics_lambda: lambda1 passed to the blended ICS score.

    Returns:
        Dict of metric values and comparison flags, or None when computation
        fails (the error is surfaced via st.error).
    """
    try:
        # channel_axis supersedes the multichannel= kwarg, which was removed
        # in scikit-image 0.23 (passing both raises TypeError there).
        ssim_a = ssim(gt_crop, a_crop, channel_axis=-1, data_range=255)
        ssim_b = ssim(gt_crop, b_crop, channel_axis=-1, data_range=255)

        psnr_a = psnr(gt_crop, a_crop, data_range=255)
        psnr_b = psnr(gt_crop, b_crop, data_range=255)

        # ICS at the requested lambda plus the two pure extremes
        # (0 = frequency term only, 1 = MS-SSIM only).
        ics_a = information_conservation_score(gt_crop, a_crop, lambda1=ics_lambda)
        ics_b = information_conservation_score(gt_crop, b_crop, lambda1=ics_lambda)
        ics_a_0 = information_conservation_score(gt_crop, a_crop, lambda1=0)
        ics_b_0 = information_conservation_score(gt_crop, b_crop, lambda1=0)
        ics_a_1 = information_conservation_score(gt_crop, a_crop, lambda1=1)
        ics_b_1 = information_conservation_score(gt_crop, b_crop, lambda1=1)

        # True when ICS ranks A above B while SSIM or PSNR disagrees.
        ics_metric_unique = ics_a > ics_b and (ssim_a < ssim_b or psnr_a < psnr_b)

        return {
            'ssim_a': ssim_a, 'ssim_b': ssim_b,
            'psnr_a': psnr_a, 'psnr_b': psnr_b,
            'ics_a': ics_a, 'ics_b': ics_b,
            'ics_a_0': ics_a_0, 'ics_b_0': ics_b_0,
            'ics_a_1': ics_a_1, 'ics_b_1': ics_b_1,
            'ics_metric_unique': ics_metric_unique,
            'bayer_ics_greater': ics_a > ics_b,
            'bayer_ssim_greater': ssim_a > ssim_b,
            'bayer_psnr_greater': psnr_a > psnr_b,
        }
    except Exception as e:
        st.error(f"Error computing metrics: {str(e)}")
        return None


# Initialize session state
if "user_authenticated" not in st.session_state:
    st.session_state.user_authenticated = False
# Remaining session-state defaults. The loop is a no-op for keys that
# already exist, exactly mirroring the per-key "if not in" checks.
_SESSION_DEFAULTS = {
    "study_completed": False,
    "full_name": "",
    "user_id": "",
    "csv_filename": "",
    "index": 0,
    "current_scene": None,
    "current_frame": None,
    "current_crops": None,
    "swap_order": False,
    "responses_data": [],
    "email_address": "",
    "auto_email": False,
    "target_responses": 25,
    "csv_headers": [
        "full_name", "user_id", "image_set", "scene", "frame", "choice",
        "timestamp", "ics_metric_unique", "bayer_ics_greater",
        "bayer_ssim_greater", "bayer_psnr_greater", "bayer_ics", "rgb_ics",
        "bayer_ics_0", "rgb_ics_0", "bayer_ics_1", "rgb_ics_1",
        "bayer_ssim", "rgb_ssim", "bayer_psnr", "rgb_psnr",
    ],
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


def _record_choice(actual_choice, gt_crop, a_crop, b_crop):
    """Shared handler for the Image A / Image B buttons.

    Computes the metrics, appends one response row (in csv_headers order),
    pushes all collected responses to the Hub, and forces a fresh image set
    on the next rerun. The original duplicated this entire body in both
    button branches and also built an in-memory CSV string that was never
    used -- that dead code has been removed.
    """
    with st.spinner("Computing metrics..."):
        metrics = compute_metrics(gt_crop, a_crop, b_crop, ics_lambda=0.5)
    if not metrics:
        return  # compute_metrics already reported the error

    response_data = [
        st.session_state.full_name,
        st.session_state.user_id,
        st.session_state.index + 1,
        st.session_state.current_scene,
        st.session_state.current_frame,
        actual_choice,
        datetime.datetime.now().isoformat(),
        metrics['ics_metric_unique'],
        metrics['bayer_ics_greater'],
        metrics['bayer_ssim_greater'],
        metrics['bayer_psnr_greater'],
        metrics['ics_a'], metrics['ics_b'],
        metrics['ics_a_0'], metrics['ics_b_0'],
        metrics['ics_a_1'], metrics['ics_b_1'],
        metrics['ssim_a'], metrics['ssim_b'],
        metrics['psnr_a'], metrics['psnr_b'],
    ]
    st.session_state.responses_data.append(response_data)
    st.session_state.index += 1
    st.session_state.current_scene = None  # force a new image set next rerun

    with st.spinner("Storing selection..."):
        push_results_to_hub(
            user_id=st.session_state.user_id,
            responses=st.session_state.responses_data,
        )

    if len(st.session_state.responses_data) >= st.session_state.target_responses:
        st.session_state.study_completed = True

    st.success("Choice and metrics recorded!")
    st.rerun()


# ------------------------------------------------------------------ main UI
st.title("Image Comparison Study with Metrics")

if not st.session_state.user_authenticated:
    # Registration form -- no file operations, everything lives in session.
    st.header("User Information")
    st.write("Please enter your information to begin the image comparison study.")
    with st.form("user_info_form"):
        full_name = st.text_input("Full Name:", placeholder="Enter your full name")
        # Anonymous id is generated instead of asking participants for one.
        user_id = str(uuid.uuid4()).replace("-", "")
        submitted = st.form_submit_button("Start Study")
        if submitted:
            is_valid, error_message = validate_inputs(full_name)
            if is_valid:
                st.session_state.full_name = full_name.strip()
                st.session_state.user_id = user_id
                st.session_state.csv_filename = f"{st.session_state.user_id}.csv"
                st.session_state.user_authenticated = True
                st.success(f"Welcome, {st.session_state.full_name}! Your responses will be stored in session.")
                st.rerun()
            else:
                st.error(error_message)
else:
    ds = load_my_dataset()

    st.header(f"Welcome, {st.session_state.full_name}")
    st.write(f"User ID: {st.session_state.user_id}")
    st.write(f"Responses collected: {len(st.session_state.responses_data)}")

    # Pick a fresh random (scene, frame, order) when none is active or on demand.
    if st.session_state.current_scene is None or st.button("Load New Image Set", key="load_new"):
        scene = random.choice([1, 2, 3, 4, 5, 8, 9, 10])  # scenes 6 and 7 excluded
        st.session_state.current_scene = f"test_scene_{scene}"
        st.session_state.current_frame = random.randint(0, 300)
        st.session_state.swap_order = random.choice([True, False])
        st.session_state.current_crops = None  # reset crops

    try:
        if st.session_state.study_completed:
            st.success("🎉 Study completed!")
            st.balloons()
        else:
            with st.spinner("Loading images..."):
                # NOTE(review): ground truth is always frame 1 -- confirm this
                # is the intended reference frame for every scene.
                gt = get_image(st.session_state.current_scene, 1, "gt")
                a = get_image(st.session_state.current_scene, st.session_state.current_frame, "bayer")  # Bayer
                b = get_image(st.session_state.current_scene, st.session_state.current_frame, "rgb")    # RGB

                gt_array = np.array(gt)
                a_array = np.array(a)
                b_array = np.array(b)

                # Crop once per image set so Streamlit reruns show identical crops.
                if st.session_state.current_crops is None:
                    gt_crop, a_crop, b_crop = random_crop(gt_array, a_array, b_array, crop_size=512)
                    st.session_state.current_crops = {'gt': gt_crop, 'a': a_crop, 'b': b_crop}
                else:
                    gt_crop = st.session_state.current_crops['gt']
                    a_crop = st.session_state.current_crops['a']
                    b_crop = st.session_state.current_crops['b']

            cols = st.columns(3)
            if st.session_state.swap_order:
                # Swapped: RGB is shown as "Image A", Bayer as "Image B".
                cols[0].image(b_crop, caption="Image A", use_container_width=True)
                cols[1].image(gt_crop, caption="Ground Truth", use_container_width=True)
                cols[2].image(a_crop, caption="Image B", use_container_width=True)
            else:
                cols[0].image(a_crop, caption="Image A", use_container_width=True)
                cols[1].image(gt_crop, caption="Ground Truth", use_container_width=True)
                cols[2].image(b_crop, caption="Image B", use_container_width=True)

            col1, col2 = st.columns([1, 1])
            with col1:
                if st.button("Image A", type="primary"):
                    # Map the displayed label back to the real source pipeline.
                    _record_choice("RGB" if st.session_state.swap_order else "Bayer",
                                   gt_crop, a_crop, b_crop)
            with col2:
                if st.button("Image B", type="primary"):
                    _record_choice("Bayer" if st.session_state.swap_order else "RGB",
                                   gt_crop, a_crop, b_crop)

            st.write(f"Completed comparisons: {st.session_state.index}")
    except Exception as e:
        # NOTE(review): st.rerun() raises a control-flow exception; a broad
        # except inside this try can swallow it. Kept as in the original,
        # but consider narrowing this handler.
        st.error(f"Error loading images: {str(e)}")
        st.write("This might be due to network issues, the Hugging Face repository being unavailable, or missing dependencies.")
        st.write("Make sure you have the following packages installed:")
        st.code("pip install torch torchvision torchmetrics scikit-image opencv-python")