# RealWoodOrAI / app.py
# CarolineM5's picture
# Upload app.py
# a7969be verified
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 7 13:43:34 2025
@author: camaac
"""
import streamlit as st
import os
import random
import pandas as pd
from PIL import Image, ImageEnhance
import numpy as np
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from skimage.exposure import match_histograms
import time
from streamlit_autorefresh import st_autorefresh
# -------------------------
# Global parameters
# -------------------------
# Folder containing the ground-truth and AI-generated images.
IMAGE_DIR: str = "images"
# Total number of image pairs each participant assesses.
NUM_PAIRS: int = 25
# CSV file for saving responses.
# NOTE(review): not referenced in the visible code, which appends results to a
# Google Sheet instead - confirm whether this is still needed.
RESULTS_FILE: str = "results.csv"
# -------------------------
# Helper functions
# -------------------------
def load_image_pair(index, image_dir=None):
    """
    Return the paths of the ground-truth image and of the AI-generated image
    for a given pair index.

    Files are named with the index zero-padded to 5 digits:
    "<idx>.png" for the ground truth and "<idx>_gen0.png" for the generated image.

    Parameters
    ----------
    index : int
        Index of the image pair.
    image_dir : str, optional
        Folder that contains the images. Defaults to the module-level
        IMAGE_DIR when omitted, which is the original behaviour.

    Returns
    -------
    tuple of (str, str)
        (ground-truth path, generated-image path). The paths are only built,
        not checked for existence.
    """
    # Resolve the default lazily so callers can point at another folder.
    if image_dir is None:
        image_dir = IMAGE_DIR
    idx_str = str(index).zfill(5)
    gt_path = os.path.join(image_dir, f"{idx_str}.png")
    pred_path = os.path.join(image_dir, f"{idx_str}_gen0.png")
    return gt_path, pred_path
def match_brightness(source_img, target_img):
    """
    Return `target_img` with its brightness enhanced by the ratio between the
    mean pixel value of `source_img` and the mean pixel value of `target_img`.

    A factor of 1 (no change) is used when the mean of `target_img` is 0.
    """
    reference_level = np.array(source_img).mean()
    current_level = np.array(target_img).mean()
    # avoid division by zero
    scale = 1 if current_level == 0 else reference_level / current_level
    return ImageEnhance.Brightness(target_img).enhance(scale)
def match_histograms_pil(img_reference, img_to_adjust):
    """
    Transfer the histogram of `img_reference` onto `img_to_adjust`
    (both images are PIL.Image objects).

    Works for multi-channel images (e.g. RGB), where each channel is matched
    separately, and for single-channel images (e.g. mode "L").

    Returns a new PIL image holding `img_to_adjust` with the adjusted histogram.
    """
    # Convert both images to numpy arrays
    ref_array = np.array(img_reference)
    adj_array = np.array(img_to_adjust)
    # A single-channel image becomes a 2-D array with no channel axis; passing
    # channel_axis=-1 there would treat every column as its own channel.
    channel_axis = -1 if adj_array.ndim > 2 else None
    # Adjust the histogram
    matched = match_histograms(adj_array, ref_array, channel_axis=channel_axis)
    # Convert back to a PIL image
    return Image.fromarray(np.uint8(matched))
# -------------------------
# Navigation via st.session_state
# -------------------------
# Default value for every session key the pages rely on. The dict is rebuilt
# on each script run, so the list defaults are fresh objects per session.
_session_defaults = {
    "page": "intro",         # page currently displayed
    "user_name": "",         # name typed on the intro page
    "current_index": 0,      # position of the pair being evaluated
    "results": [],           # one 0/1 entry per answered pair
    "list_pair": [],         # shuffled (label, path) pairs
    "list_pair_ID": [],      # image indices drawn for this session
    "results_tot": 0,        # number of correct answers
    "submitted": False,      # set once the results have been saved
}
for _key, _default in _session_defaults.items():
    # Only seed keys that are missing so values survive reruns.
    if _key not in st.session_state:
        st.session_state[_key] = _default
# -------------------------
# Intro page
# -------------------------
if st.session_state.page == "intro":
    # Intro page: show the instructions, ask for the participant's name and,
    # once "Start Evaluation" is clicked with a non-empty name, draw the image
    # pairs for this session and switch to the evaluation page.
    st.title("AI Wood Generation Evaluation Study")
    st.markdown(
        """
**Welcome!**
In this study, you will be shown pairs of wood surface images.
One image is a real photograph and the other is generated by AI.
Your task is to select the image you believe is **real**.
⌛ *Each image pair will be visible for 10 seconds only, be quick!* ⌛
Please enter your name below and click **Start Evaluation** to begin.
"""
    )
    name = st.text_input("Enter your name:")
    if st.button("Start Evaluation") and name:
        st.session_state.user_name = name
        # Draw the pairs exactly once, here. Sampling on every render of this
        # page would overwrite list_pair_ID while list_pair kept growing, so
        # the pairs shown (the first NUM_PAIRS entries) could differ from the
        # IDs that are saved with the results.
        pair_ids = random.sample(range(1, 51), NUM_PAIRS)
        pairs = []
        for pair_id in pair_ids:
            gt_path, pred_path = load_image_pair(pair_id)
            pair = [("GT", gt_path), ("Pred", pred_path)]
            # Randomise which side the real image appears on.
            random.shuffle(pair)
            pairs.append(pair)
        # Assign both lists together so they always describe the same draw.
        st.session_state.list_pair_ID = pair_ids
        st.session_state.list_pair = pairs
        st.session_state.page = "evaluation"
        # st.rerun() ends the current run; nothing after it executes.
        st.rerun()
    # Nothing else is rendered while the intro page is active.
    st.stop()
# -------------------------
# Evaluation page
# -------------------------
def _open_grayscale(path):
    """Open the image at `path` and return it converted to grayscale (mode "L")."""
    # The context manager releases the file handle; convert() returns a
    # separate image that stays usable afterwards.
    with Image.open(path) as img:
        return img.convert("L")


if st.session_state.page == "evaluation":
    # Evaluation page: shows one image pair at a time with a countdown, records
    # whether the participant picked the real image, and after the last pair
    # asks for a confidence rating and saves the results to a Google Sheet.
    st.title("AI Wood Generation Evaluation")
    time_limit = 10  # seconds during which each pair stays visible

    # (Re)start the countdown on first entry and right after moving to a new
    # pair. `page_changed` is only read once `start_time` exists, by which
    # time it has been set below.
    if "start_time" not in st.session_state or st.session_state.page_changed:
        st.session_state.start_time = time.time()
        st.session_state.page_changed = False

    # If all pairs have been evaluated, ask for confidence and save the results
    if st.session_state.current_index + 1 > NUM_PAIRS:
        st.markdown("<h4>How confident were you in your answers?</h4>", unsafe_allow_html=True)
        confidence = st.radio(
            " ",
            [
                "Not confident at all",
                "Slightly confident",
                "Moderately confident",
                "Very confident",
                "Extremely confident",
            ],
            index=2,
            horizontal=True,
        )
        # Ignore further clicks once a submission has been stored; otherwise
        # every extra click would append another row for the same session.
        if st.button("Submit") and not st.session_state.submitted:
            # Calculating result
            correct_guess = np.array(st.session_state.results)
            nb_correct = int(np.sum(correct_guess))
            st.session_state.results_tot = nb_correct
            # Save result
            scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
            creds = ServiceAccountCredentials.from_json_keyfile_name('glass-flux-456209-d4-6fc4b7d9d274.json', scope)
            client = gspread.authorize(creds)
            sh = client.open('Results_woodAI').worksheet('test')
            row = [
                st.session_state.user_name,
                confidence,
                str(nb_correct),
                ",".join(map(str, st.session_state.list_pair_ID)),
                ",".join(map(str, correct_guess)),
            ]
            sh.append_row(row)
            # Only flag success after the row was written, so a failed save
            # can be retried.
            st.session_state.submitted = True
        if st.session_state.submitted:
            # Shown on every rerun after a successful save, not only on the
            # run in which "Submit" was clicked.
            st.success(f"Number of correct answers: {st.session_state.results_tot}/{NUM_PAIRS}")
            st.success("Thank you for completing the evaluation!", icon=":material/park:")
            if st.button("See detailed results"):
                st.session_state.page = "detailed_results"
                st.rerun()
        st.stop()

    # Rerun the script every second so the countdown is refreshed; the key is
    # tied to the current pair index.
    st_autorefresh(interval=1000, key=f"timer_{st.session_state.current_index}")
    st.write(f"Image Pair {st.session_state.current_index+1} of {NUM_PAIRS}")

    # Load the (already shuffled) pair for the current index
    pair = st.session_state.list_pair[st.session_state.current_index]
    img1 = _open_grayscale(pair[0][1])
    img2 = _open_grayscale(pair[1][1])

    elapsed = time.time() - st.session_state.start_time
    remaining = max(0, time_limit - int(elapsed))
    st.markdown(f"**Time remaining:** {remaining} seconds")
    st.progress(int((remaining / time_limit) * 100))

    # Uniform grey image displayed in place of both pictures once time is up.
    placeholder = Image.new("L", img1.size, 128)
    time_is_up = elapsed >= time_limit
    shown = zip(st.columns(2), (img1, img2), ("Image 1", "Image 2"))
    for column, image, caption in shown:
        with column:
            if time_is_up:
                st.image(placeholder, caption="Time’s up!", use_container_width=True)
            else:
                st.image(image, caption=caption, use_container_width=True)

    choice = st.radio("Select the real image: ", options=["1", "2"], horizontal=True)
    if st.button("Next"):
        # choice "1" refers to pair[0], choice "2" to pair[1]; the answer is
        # correct when the selected entry is labelled "GT".
        correct_guess = int(pair[int(choice) - 1][0] == "GT")
        st.session_state.results.append(correct_guess)
        # Move on to the next pair
        st.session_state.current_index += 1
        st.session_state.start_time = time.time()
        st.session_state.page_changed = True
        st.rerun()
if st.session_state.page == "detailed_results":
    # Detailed results page: for each evaluated pair, show the real image next
    # to the AI-generated one and whether the participant's answer was correct.
    st.title("Detailed Results")
    outcomes = np.array(st.session_state.results)
    for pair_no in range(NUM_PAIRS):
        first, second = st.session_state.list_pair[pair_no]
        # Each entry is (label, path); put the "GT" entry first regardless of
        # the shuffled display order.
        if first[0] == "GT":
            real_entry, ai_entry = first, second
        else:
            real_entry, ai_entry = second, first
        real_img = Image.open(real_entry[1]).convert("L")
        ai_img = Image.open(ai_entry[1]).convert("L")

        left, right = st.columns(2)
        with left:
            st.image(real_img, caption="Real", use_container_width=True)
        with right:
            st.image(ai_img, caption="AI", use_container_width=True)

        if outcomes[pair_no]:
            st.success("✅ Correct")
        else:
            st.error("❌ Incorrect")
        st.markdown("---")