Spaces:
Sleeping
Sleeping
| import os | |
| # Force cache dirs to a writable location | |
| # os.environ["HF_HOME"] = "/home/user/.cache" # or "/tmp/hf_home" if you want ephemeral | |
| # os.environ["HF_HUB_CACHE"] = "/home/user/.cache/hub" | |
| # os.environ["TRANSFORMERS_CACHE"] = "/home/user/.cache/transformers" | |
| # os.environ["HF_DATASETS_CACHE"] = "/home/user/.cache/datasets" | |
| os.environ['HF_HOME'] = "/tmp/hf_cache" | |
| CACHE_DIR = "/tmp/hf_cache" | |
| import streamlit as st | |
| from huggingface_hub import hf_hub_download | |
| from datasets import load_dataset | |
| from PIL import Image | |
| import numpy as np | |
| import random | |
| import csv | |
| import re | |
| import datetime | |
| import cv2 | |
| import torch | |
| import torch.fft | |
| import smtplib | |
| import io | |
| from email.mime.multipart import MIMEMultipart | |
| from email.mime.text import MIMEText | |
| from email.mime.base import MIMEBase | |
| from email import encoders | |
| from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr | |
| from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure | |
| from datasets import Dataset | |
| from huggingface_hub import HfApi | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
| import uuid | |
| load_dotenv() # will read .env and inject into os.environ | |
def push_results_to_hub(user_id, responses, headers=None):
    """Push one participant's collected responses to the results dataset repo.

    Each participant is stored as its own split named ``user_id``. The app
    passes the cumulative list of responses on every call, so each push
    presumably replaces that split's contents — verify against the installed
    ``datasets`` version.

    Args:
        user_id: Split name for this participant (the app uses a hex UUID).
        responses: Sequence of rows (lists), each in the column order of
            ``headers``.
        headers: Column names for the rows. Defaults to
            ``st.session_state.csv_headers``.

    Returns:
        True after a successful push; False when ``responses`` is empty and
        nothing was uploaded. Errors raised while building or pushing the
        dataset propagate to the caller.
    """
    if not responses:
        # Nothing to upload; avoid creating an empty split on the Hub.
        return False
    if headers is None:
        headers = st.session_state.csv_headers

    # Token comes from the environment (Space secrets, or .env via load_dotenv).
    # None lets huggingface_hub fall back to its default token lookup.
    hf_token = os.environ.get("HF_TOKEN")
    repo_id = "rain-maker/image-study-results"  # private results dataset repo

    df = pd.DataFrame(responses, columns=headers)
    ds = Dataset.from_pandas(df, preserve_index=False)
    ds.push_to_hub(repo_id, token=hf_token, split=user_id, private=True)
    return True
# ICS Metric Implementation
def kl_divergence(p, q, eps=1e-10):
    """Return KL(p || q) summed over the last two dimensions.

    ``eps`` is added to both inputs before the log so zero entries do not
    produce ``log(0)`` or a division by zero. Leading dimensions are kept, so
    batched inputs yield one divergence per leading index.
    """
    p_safe, q_safe = p + eps, q + eps
    log_ratio = torch.log(p_safe / q_safe)
    return (p_safe * log_ratio).sum(dim=(-1, -2))
def power_spectrum(x):
    """Return the 2-D FFT power spectrum of ``x``, normalized to sum to 1.

    The FFT and the normalization both act on the last two dimensions of ``x``;
    any leading dimensions are treated as a batch.
    """
    spectrum = torch.fft.fft2(x)
    energy = spectrum.abs().pow(2)
    total = energy.sum(dim=(-1, -2), keepdim=True)
    return energy / total
def information_conservation_score(x, x_hat, lambda1=0.5):
    """Blend MS-SSIM with a spectral-similarity term into a single score.

    Args:
        x: Reference image as a numpy array, shaped (H, W, C) or (H, W).
            ``uint8`` input is scaled to [0, 1]; other dtypes are used as-is
            (MS-SSIM below assumes ``data_range=1.0``).
        x_hat: Image to compare against ``x``, same layout.
        lambda1: Weight of the MS-SSIM term; ``1 - lambda1`` weights the
            frequency term.

    Returns:
        Python float: ``lambda1 * mean(MS-SSIM) + (1 - lambda1) * mean(1 - KL / log(N))``.
        KL is taken between the per-channel normalized spatial power spectra,
        and N is the total element count of the spectrum tensor.
    """

    def _to_nchw(arr):
        # Float tensor shaped [1, C, H, W]; grayscale gets a singleton channel.
        t = torch.from_numpy(arr).float()
        if arr.dtype == np.uint8:
            t = t / 255.0
        if t.ndim == 2:
            t = t.unsqueeze(-1)
        return t.permute(2, 0, 1).unsqueeze(0)

    x_t = _to_nchw(x)
    x_hat_t = _to_nchw(x_hat)

    # MS-SSIM
    ms_ssim_metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
    ms_ssim_score = ms_ssim_metric(x_t, x_hat_t)

    # Frequency-domain part. power_spectrum() transforms the last two dims, so it
    # must receive the NCHW tensors (spatial H, W last). The raw HWC arrays would
    # make it transform over (width, colour channel) instead.
    p_spec = power_spectrum(x_t)
    q_spec = power_spectrum(x_hat_t)
    kl = kl_divergence(p_spec, q_spec)  # one value per channel
    kl_normalized = kl / torch.log(torch.tensor(p_spec.numel(), dtype=torch.float32))
    freq_score = 1.0 - kl_normalized

    return (lambda1 * ms_ssim_score.mean() + (1 - lambda1) * freq_score.mean()).item()
@st.cache_resource
def load_my_dataset():
    """Load the ``test`` split of the ``rain-maker/RAW-RAIN-sample`` dataset.

    Cached with ``st.cache_resource`` because Streamlit re-runs the script on
    every interaction; without the cache the dataset would be rebuilt each time.
    """
    # load_dataset() has no ``repo_type`` parameter (it always targets dataset
    # repos). Unknown keyword arguments are forwarded to the builder config,
    # which rejects them, so the old ``repo_type="dataset"`` argument is omitted.
    return load_dataset(
        "rain-maker/RAW-RAIN-sample",
        split="test",  # or use "imagefolder" if repo is raw files
    )
def get_image(scene, frame, kind):
    """Download one frame from the sample dataset repo and open it with PIL.

    The file fetched is ``{kind}_test/{scene}/rgb_output/output_{frame}.png``
    from the ``rain-maker/RAW-RAIN-sample`` dataset repo. hf_hub_download
    returns a local path to the downloaded file.
    """
    local_path = hf_hub_download(
        repo_id="rain-maker/RAW-RAIN-sample",
        # must match repo structure
        filename=f"{kind}_test/{scene}/rgb_output/output_{frame}.png",
        repo_type="dataset",
    )
    return Image.open(local_path)
| # def get_image(scene, frame, kind): | |
| # # Construct relative path exactly like in your hf_hub_download version | |
| # rel_path = f"{kind}_test/{scene}/rgb_output/output_{frame}.png" | |
| # # Find the entry with that filename | |
| # record = next((item for item in ds if item["image"].filename.endswith(rel_path)), None) | |
| # if record is None: | |
| # raise FileNotFoundError(f"{rel_path} not found in dataset") | |
| # return record["image"] # already a PIL.Image | |
def validate_inputs(full_name):
    """Check the registration form's name field.

    Returns an ``(is_valid, error_message)`` pair; the message is empty when
    the name contains at least one non-whitespace character.
    """
    if full_name.strip():
        return True, ""
    return False, "Please enter your full name."
def random_crop(gt_array, a_array, b_array, crop_size=512):
    """Cut the same random window out of three aligned image arrays.

    The window is at most ``crop_size`` x ``crop_size`` and must fit inside
    every array, so the smallest height and width among the three are used.
    If an array is smaller than ``crop_size`` in a dimension, the window
    shrinks to fit instead of raising.

    Args:
        gt_array, a_array, b_array: Arrays indexed as ``[row, column, ...]``.
        crop_size: Requested side length of the window.

    Returns:
        Tuple ``(gt_crop, a_crop, b_crop)``: slices of the inputs taken at the
        same offset and with identical height and width.
    """
    arrays = (gt_array, a_array, b_array)
    h = min(arr.shape[0] for arr in arrays)
    w = min(arr.shape[1] for arr in arrays)
    crop_h = min(crop_size, h)
    crop_w = min(crop_size, w)
    # Row offset first, then column offset (same random call order as before).
    y = random.randint(0, h - crop_h)
    x = random.randint(0, w - crop_w)
    window = (slice(y, y + crop_h), slice(x, x + crop_w))
    return gt_array[window], a_array[window], b_array[window]
def compute_metrics(gt_crop, a_crop, b_crop, ics_lambda=0.5):
    """Score two candidate crops against the ground-truth crop.

    The ``bayer_*`` keys of the result treat ``a_crop`` as the Bayer result
    and ``b_crop`` as the RGB result. SSIM and PSNR use ``data_range=255``,
    i.e. they assume 8-bit images with the channel axis last.

    Args:
        gt_crop, a_crop, b_crop: Equally shaped image arrays.
        ics_lambda: MS-SSIM weight for the main ICS score. ICS is also
            reported at weights 0 (frequency term only) and 1 (MS-SSIM only).

    Returns:
        Dict of metric values and comparison flags, or None if any metric
        raised (the error is shown with ``st.error``).
    """
    try:
        # SSIM metrics; channel_axis=-1 replaces the deprecated ``multichannel``
        # flag, which newer scikit-image versions no longer honour.
        ssim_a = ssim(gt_crop, a_crop, channel_axis=-1, data_range=255)
        ssim_b = ssim(gt_crop, b_crop, channel_axis=-1, data_range=255)
        # PSNR metrics
        psnr_a = psnr(gt_crop, a_crop, data_range=255)
        psnr_b = psnr(gt_crop, b_crop, data_range=255)
        # ICS metrics at the requested weight and at both extremes
        ics_a = information_conservation_score(gt_crop, a_crop, lambda1=ics_lambda)
        ics_b = information_conservation_score(gt_crop, b_crop, lambda1=ics_lambda)
        ics_a_0 = information_conservation_score(gt_crop, a_crop, lambda1=0)
        ics_b_0 = information_conservation_score(gt_crop, b_crop, lambda1=0)
        ics_a_1 = information_conservation_score(gt_crop, a_crop, lambda1=1)
        ics_b_1 = information_conservation_score(gt_crop, b_crop, lambda1=1)
        # True when ICS prefers A while SSIM or PSNR prefers B.
        ics_metric_unique = (ics_a > ics_b and (ssim_a < ssim_b or psnr_a < psnr_b))
        bayer_ics_greater = ics_a > ics_b
        bayer_ssim_greater = ssim_a > ssim_b
        bayer_psnr_greater = psnr_a > psnr_b
        return {
            'ssim_a': ssim_a, 'ssim_b': ssim_b,
            'psnr_a': psnr_a, 'psnr_b': psnr_b,
            'ics_a': ics_a, 'ics_b': ics_b,
            'ics_a_0': ics_a_0, 'ics_b_0': ics_b_0,
            'ics_a_1': ics_a_1, 'ics_b_1': ics_b_1,
            'ics_metric_unique': ics_metric_unique,
            'bayer_ics_greater': bayer_ics_greater,
            'bayer_ssim_greater': bayer_ssim_greater,
            'bayer_psnr_greater': bayer_psnr_greater
        }
    except Exception as e:
        # UI boundary: report the failure and let the caller skip recording.
        st.error(f"Error computing metrics: {str(e)}")
        return None
# Initialize session state: give every key the app reads a default value,
# leaving values already present in this browser session untouched.
_SESSION_DEFAULTS = {
    "user_authenticated": False,
    "study_completed": False,
    "full_name": "",
    "user_id": "",
    "csv_filename": "",
    "index": 0,
    "current_scene": None,
    "current_frame": None,
    "current_crops": None,
    "swap_order": False,
    "responses_data": [],
    "email_address": "",
    "auto_email": False,
    "target_responses": 25,
    # Column order of every stored response row.
    "csv_headers": [
        "full_name", "user_id", "image_set", "scene", "frame", "choice", "timestamp",
        "ics_metric_unique", "bayer_ics_greater", "bayer_ssim_greater", "bayer_psnr_greater",
        "bayer_ics", "rgb_ics", "bayer_ics_0", "rgb_ics_0", "bayer_ics_1", "rgb_ics_1",
        "bayer_ssim", "rgb_ssim", "bayer_psnr", "rgb_psnr"
    ],
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Main app logic
#
# Streamlit re-executes this script on every interaction; everything that must
# survive a rerun is kept in st.session_state.

# Metric keys appended after the seven bookkeeping columns of each response row,
# in the order of st.session_state.csv_headers ("a" = Bayer, "b" = RGB).
_METRIC_KEYS = (
    "ics_metric_unique", "bayer_ics_greater", "bayer_ssim_greater", "bayer_psnr_greater",
    "ics_a", "ics_b", "ics_a_0", "ics_b_0", "ics_a_1", "ics_b_1",
    "ssim_a", "ssim_b", "psnr_a", "psnr_b",
)


def _start_new_image_set():
    """Pick a random scene, frame and on-screen order, and drop cached crops."""
    scene = random.choice([1, 2, 3, 4, 5, 8, 9, 10])
    frame = random.randint(0, 300)
    swap_order = random.choice([True, False])
    st.session_state.current_scene = f"test_scene_{scene}"
    st.session_state.current_frame = frame
    st.session_state.swap_order = swap_order
    st.session_state.current_crops = None  # Reset crops


def _load_current_crops():
    """Return (gt, bayer, rgb) crops for the current set, loading them once per set."""
    crops = st.session_state.current_crops
    if crops is None:
        scene = st.session_state.current_scene
        frame = st.session_state.current_frame
        with st.spinner("Loading images..."):
            # NOTE(review): ground truth is always frame 1 while the candidates
            # use `frame`; confirm the GT is static for each scene.
            gt = get_image(scene, 1, "gt")
            a = get_image(scene, frame, "bayer")  # Bayer
            b = get_image(scene, frame, "rgb")  # RGB
        gt_crop, a_crop, b_crop = random_crop(np.array(gt), np.array(a), np.array(b), crop_size=512)
        crops = {'gt': gt_crop, 'a': a_crop, 'b': b_crop}
        st.session_state.current_crops = crops
    return crops['gt'], crops['a'], crops['b']


def _record_choice(actual_choice, gt_crop, a_crop, b_crop):
    """Score the crops, store a response row, upload all rows, then rerun.

    Args:
        actual_choice: "Bayer" or "RGB" - the method behind the clicked image.
        gt_crop, a_crop, b_crop: Ground-truth, Bayer and RGB crops, always in
            this (unswapped) order so the metric columns stay aligned.
    """
    with st.spinner("Computing metrics..."):
        metrics = compute_metrics(gt_crop, a_crop, b_crop, ics_lambda=0.5)
    if metrics is None:
        # compute_metrics already reported the error with st.error.
        return
    # Row layout must match st.session_state.csv_headers.
    response_data = [
        st.session_state.full_name,
        st.session_state.user_id,
        st.session_state.index + 1,
        st.session_state.current_scene,
        st.session_state.current_frame,
        actual_choice,
        datetime.datetime.now().isoformat(),
        *(metrics[key] for key in _METRIC_KEYS),
    ]
    st.session_state.responses_data.append(response_data)
    st.session_state.index += 1
    st.session_state.current_scene = None  # Force new image set on the rerun
    with st.spinner("Storing selection..."):
        push_results_to_hub(
            user_id=st.session_state.user_id,
            responses=st.session_state.responses_data,
        )
    if len(st.session_state.responses_data) >= st.session_state.target_responses:
        st.session_state.study_completed = True
    st.success("Choice and metrics recorded!")
    st.rerun()


st.title("Image Comparison Study with Metrics")
if not st.session_state.user_authenticated:
    # User registration form
    st.header("User Information")
    st.write("Please enter your information to begin the image comparison study.")
    with st.form("user_info_form"):
        full_name = st.text_input("Full Name:", placeholder="Enter your full name")
        # Anonymous per-session ID; also used as the results split name.
        user_id = str(uuid.uuid4()).replace("-", "")
        submitted = st.form_submit_button("Start Study")
    if submitted:
        is_valid, error_message = validate_inputs(full_name)
        if is_valid:
            # Store user information in session state only - no file operations.
            st.session_state.full_name = full_name.strip()
            st.session_state.user_id = user_id
            st.session_state.csv_filename = f"{st.session_state.user_id}.csv"
            st.session_state.user_authenticated = True
            st.success(f"Welcome, {st.session_state.full_name}! Your responses will be stored in session.")
            st.rerun()
        else:
            st.error(error_message)
else:
    # NOTE(review): `ds` is not used below; images are fetched per file by
    # get_image(). Kept so the dataset is still loaded as before.
    ds = load_my_dataset()
    # Main image comparison interface
    st.header(f"Welcome, {st.session_state.full_name}")
    st.write(f"User ID: {st.session_state.user_id}")
    st.write(f"Responses collected: {len(st.session_state.responses_data)}")

    # Render the button on every run. Previously it was short-circuited away
    # whenever no scene was selected, e.g. right after each answer.
    load_new_clicked = st.button("Load New Image Set", key="load_new")
    if load_new_clicked or st.session_state.current_scene is None:
        _start_new_image_set()

    try:
        if st.session_state.study_completed:
            st.success("π Study completed!")
            st.balloons()
        else:
            gt_crop, a_crop, b_crop = _load_current_crops()

            # swap_order shows RGB as "Image A" and Bayer as "Image B".
            swapped = st.session_state.swap_order
            left_crop, right_crop = (b_crop, a_crop) if swapped else (a_crop, b_crop)
            cols = st.columns(3)
            cols[0].image(left_crop, caption="Image A", use_container_width=True)
            cols[1].image(gt_crop, caption="Ground Truth", use_container_width=True)
            cols[2].image(right_crop, caption="Image B", use_container_width=True)

            col1, col2 = st.columns([1, 1])
            with col1:
                if st.button("Image A", type="primary"):
                    _record_choice("RGB" if swapped else "Bayer", gt_crop, a_crop, b_crop)
            with col2:
                if st.button("Image B", type="primary"):
                    _record_choice("Bayer" if swapped else "RGB", gt_crop, a_crop, b_crop)
            st.write(f"Completed comparisons: {st.session_state.index}")
    except Exception as e:
        st.error(f"Error loading images: {str(e)}")
        st.write("This might be due to network issues, the Hugging Face repository being unavailable, or missing dependencies.")
        st.write("Make sure you have the following packages installed:")
        st.code("pip install torch torchvision torchmetrics scikit-image opencv-python")