Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Mon Apr 7 13:43:34 2025 | |
| @author: camaac | |
| """ | |
| import streamlit as st | |
| import os | |
| import random | |
| import pandas as pd | |
| from PIL import Image, ImageEnhance | |
| import numpy as np | |
| import gspread | |
| from oauth2client.service_account import ServiceAccountCredentials | |
| from skimage.exposure import match_histograms | |
| import time | |
| from streamlit_autorefresh import st_autorefresh | |
# -------------------------
# Global parameters
# -------------------------
IMAGE_DIR: str = "images"  # directory holding both the real and the generated images
NUM_PAIRS: int = 25  # number of image pairs each participant assesses
RESULTS_FILE: str = "results.csv"  # CSV path for responses (not referenced by the code shown; results go to Google Sheets)
# -------------------------
# Helper functions
# -------------------------
def load_image_pair(index):
    """
    Build the file paths of one image pair for the given numeric index.

    Both files live in IMAGE_DIR and are named after the index zero-padded to
    five digits; the AI-generated file carries an extra ``_gen0`` suffix.
    Returns a ``(ground_truth_path, generated_path)`` tuple. The files are
    neither opened nor checked for existence.
    """
    stem = str(index).zfill(5)
    ground_truth, generated = (
        os.path.join(IMAGE_DIR, stem + suffix) for suffix in (".png", "_gen0.png")
    )
    return ground_truth, generated
def match_brightness(source_img, target_img):
    """
    Rescale the brightness of ``target_img`` towards that of ``source_img``.

    The brightness factor is the ratio mean(source pixels) / mean(target
    pixels), applied with ``ImageEnhance.Brightness``. When the target's mean
    is 0 a factor of 1 is used instead, to avoid dividing by zero.
    Returns the image produced by the enhancer.
    """
    mean_source = np.array(source_img).mean()
    mean_target = np.array(target_img).mean()
    gain = 1 if mean_target == 0 else mean_source / mean_target
    return ImageEnhance.Brightness(target_img).enhance(gain)
def match_histograms_pil(img_reference, img_to_adjust):
    """
    Return a copy of `img_to_adjust` whose intensity histogram is matched to
    that of `img_reference` (both are PIL.Image objects).

    Single-channel images (2-D arrays, e.g. mode "L") are matched as one
    channel; multi-channel images (3-D arrays, e.g. mode "RGB") are matched
    channel by channel, so both images must then have the same channel count.
    The result is converted to 8-bit, so 8-bit inputs are assumed.
    """
    # Convert both images to numpy arrays.
    ref_array = np.array(img_reference)
    adj_array = np.array(img_to_adjust)
    # Only declare a channel axis when there is one: with channel_axis=-1 on a
    # 2-D grayscale array, skimage treats every pixel column as a separate
    # "channel" (and raises if the two widths differ).
    channel_axis = -1 if adj_array.ndim == 3 else None
    matched = match_histograms(adj_array, ref_array, channel_axis=channel_axis)
    # match_histograms returns floats; round and clip instead of truncating.
    matched = np.clip(np.rint(matched), 0, 255).astype(np.uint8)
    # Convert back to a PIL image.
    return Image.fromarray(matched)
# -------------------------
# Navigation via st.session_state
# -------------------------
# Each per-user value is seeded only when it is absent, so whatever a page has
# stored is kept when the script runs again.
_session_defaults = {
    "page": "intro",  # current screen: "intro", "evaluation" or "detailed_results"
    "user_name": "",
    "current_index": 0,  # index of the pair currently shown
    "results": [],  # one 1 (correct) / 0 (wrong) entry per answered pair
    "list_pair": [],  # per pair: [(label, path), (label, path)] in display order
    "list_pair_ID": [],  # numeric image indices of the drawn pairs
    "results_tot": 0,  # number of correct answers, stored on submission
    "submitted": False,  # becomes True once results have been saved
}
for _state_key, _initial in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
# -------------------------
# Intro page
# -------------------------
if st.session_state.page == "intro":
    st.title("AI Wood Generation Evaluation Study")
    st.markdown(
        """
        **Welcome!**
        In this study, you will be shown pairs of wood surface images.
        One image is a real photograph and the other is generated by AI.
        Your task is to select the image you believe is **real**.
        ⌛ *Each image pair will be visible for 10 seconds only, be quick!* ⌛
        Please enter your name below and click **Start Evaluation** to begin.
        """
    )
    name = st.text_input("Enter your name:")
    if st.button("Start Evaluation") and name.strip():
        st.session_state.user_name = name.strip()
        # Draw the pairs exactly once, here, before switching page. Drawing them
        # on every intro run (this page reruns whenever a widget changes) kept
        # appending to list_pair, so the pairs shown (its first NUM_PAIRS
        # entries) no longer matched list_pair_ID, which is saved with the results.
        # Image indices are sampled from 1..50.
        pair_ids = random.sample(range(1, 51), NUM_PAIRS)
        pairs = []
        for pair_id in pair_ids:
            gt_path, pred_path = load_image_pair(pair_id)
            pair = [("GT", gt_path), ("Pred", pred_path)]
            random.shuffle(pair)  # randomise which side shows the real image
            pairs.append(pair)
        st.session_state.list_pair_ID = pair_ids
        st.session_state.list_pair = pairs
        st.session_state.page = "evaluation"
        st.rerun()
    st.stop()
# -------------------------
# Evaluation page
# -------------------------
if st.session_state.page == "evaluation":
    st.title("AI Wood Generation Evaluation")
    # Seconds each pair stays visible (the intro text announces 10 seconds).
    time_limit = 10

    # (Re)start the countdown on first entry to this page and after every "Next".
    if "start_time" not in st.session_state or st.session_state.page_changed:
        st.session_state.start_time = time.time()
        st.session_state.page_changed = False

    # All pairs answered: ask for a confidence rating, then save the results once.
    if st.session_state.current_index >= NUM_PAIRS:
        st.markdown("<h4>How confident were you in your answers?</h4>", unsafe_allow_html=True)
        confidence = st.radio(
            " ",
            [
                "Not confident at all",
                "Slightly confident",
                "Moderately confident",
                "Very confident",
                "Extremely confident",
            ],
            index=2,  # preselect "Moderately confident"
            horizontal=True,
        )
        # Offer "Submit" only until the row is saved; otherwise every further
        # click appended a duplicate row to the results sheet.
        if not st.session_state.submitted and st.button("Submit"):
            correct_guess = np.array(st.session_state.results)  # 1 = correct, 0 = wrong, per pair
            nb_correct = np.sum(correct_guess)
            st.session_state.results_tot = nb_correct
            # Append one row to the 'test' worksheet of the 'Results_woodAI' spreadsheet.
            # NOTE(review): the service-account key is read from a JSON file shipped
            # next to the app; consider st.secrets so the key is not committed.
            scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
            creds = ServiceAccountCredentials.from_json_keyfile_name('glass-flux-456209-d4-6fc4b7d9d274.json', scope)
            client = gspread.authorize(creds)
            sh = client.open('Results_woodAI').worksheet('test')
            row = [
                st.session_state.user_name,
                confidence,
                str(nb_correct),
                ",".join(map(str, st.session_state.list_pair_ID)),
                ",".join(map(str, correct_guess)),
            ]
            sh.append_row(row)
            # Flagged only after the append succeeded, so a failed save can be retried.
            st.session_state.submitted = True
        if st.session_state.submitted:
            # Shown on every run after saving, not only in the run where Submit was clicked.
            st.success(f"Number of correct answers: {st.session_state.results_tot}/{NUM_PAIRS}")
            st.success("Thank you for completing the evaluation!", icon=":material/park:")
            if st.button("See detailed results"):
                st.session_state.page = "detailed_results"
                st.rerun()
        st.stop()

    elapsed = time.time() - st.session_state.start_time
    images_visible = elapsed < time_limit
    # Rerun once per second to tick the countdown, but only while it is running:
    # once the images are hidden there is nothing left to refresh.
    if images_visible:
        st_autorefresh(interval=1000, key=f"timer_{st.session_state.current_index}")

    st.write(f"Image Pair {st.session_state.current_index+1} of {NUM_PAIRS}")

    # Entries are (label, path) in the display order drawn on the intro page.
    pair = st.session_state.list_pair[st.session_state.current_index]
    shown = []
    for _label, path in pair:
        with Image.open(path) as raw:  # release the file once the grayscale copy exists
            shown.append(raw.convert("L"))

    remaining = max(0, time_limit - int(elapsed))
    st.markdown(f"**Time remaining:** {remaining} seconds")
    st.progress(int((remaining / time_limit) * 100))

    for column, img, caption in zip(st.columns(2), shown, ("Image 1", "Image 2")):
        with column:
            if images_visible:
                st.image(img, caption=caption, use_container_width=True)
            else:
                # Mid-grey stand-in with the hidden image's own size.
                st.image(Image.new("L", img.size, 128), caption="Time’s up!", use_container_width=True)

    # Per-pair key: without it every pair reuses the same radio widget, so the
    # previous pair's answer was carried over as the preselected choice.
    choice = st.radio(
        "Select the real image: ",
        options=["1", "2"],
        horizontal=True,
        key=f"choice_{st.session_state.current_index}",
    )
    if st.button("Next"):
        # choice "1"/"2" selects the left/right entry of the pair.
        chosen_label = pair[int(choice) - 1][0]
        st.session_state.results.append(1 if chosen_label == "GT" else 0)
        # Move on to the next pair and restart its countdown.
        st.session_state.current_index += 1
        st.session_state.start_time = time.time()
        st.session_state.page_changed = True
        st.rerun()
if st.session_state.page == "detailed_results":
    st.title("Detailed Results")
    outcomes = np.array(st.session_state.results)
    for k in range(NUM_PAIRS):
        entries = st.session_state.list_pair[k]
        # Entries are (label, path) in display order; locate the real image's slot.
        gt_slot = 0 if entries[0][0] == "GT" else 1
        real_img = Image.open(entries[gt_slot][1]).convert("L")
        ai_img = Image.open(entries[1 - gt_slot][1]).convert("L")
        left, right = st.columns(2)
        with left:
            st.image(real_img, caption="Real", use_container_width=True)
        with right:
            st.image(ai_img, caption="AI", use_container_width=True)
        # outcomes[k] is 1 when the participant picked the real image.
        if outcomes[k]:
            st.success("✅ Correct")
        else:
            st.error("❌ Incorrect")
        st.markdown("---")