Spaces:
Sleeping
Sleeping
Commit ·
6852b64
0
Parent(s):
Initial commit of AHDRNet Gradio app
Browse files- .gitattributes +1 -0
- app.py +141 -0
- finetuning_run_01/checkpoints/best_model.pth +3 -0
- model.py +105 -0
- requirements.txt +5 -0
- utils.py +12 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from torchvision.transforms.functional import to_pil_image
|
| 8 |
+
import argparse # Needed for the model initialization
|
| 9 |
+
|
| 10 |
+
# --- 1. Import Your Model Architecture ---
|
| 11 |
+
# This assumes your model.py file is in the same directory.
|
| 12 |
+
from model import AHDR
|
| 13 |
+
|
| 14 |
+
# --- 2. Define Constants and Helper Functions ---
GAMMA = 2.2


def load_and_preprocess_images(file_paths):
    """Read the uploaded LDR shots, pick three exposures, and build model inputs.

    The images are ranked by mean grayscale brightness; the darkest, median,
    and brightest frames act as the under/normal/over-exposed triple. Each
    selected frame is paired with a gamma-linearised version (divided by a
    proxy exposure time) to form the 6-channel input AHDRNet expects.

    Args:
        file_paths: list of local file paths from the Gradio uploader.

    Returns:
        Three tensors of shape (1, 6, H, W): under-, normal-, over-exposed.

    Raises:
        gr.Error: if fewer than 3 images are supplied or readable.
    """
    if len(file_paths) < 3:
        raise gr.Error("Please upload at least 3 images for the best results. The model uses under, normal, and over-exposed shots.")

    # Collect (brightness, image) pairs, silently skipping unreadable files.
    readable = []
    for path in file_paths:
        bgr = cv2.imread(path)
        if bgr is None:
            continue
        luma = np.mean(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY))
        readable.append((luma, bgr))

    if len(readable) < 3:
        raise gr.Error(f"Could only read {len(readable)} valid images, but need at least 3.")

    # Sort strictly on brightness (never compare the image arrays themselves).
    readable.sort(key=lambda pair: pair[0])
    picks = (
        readable[0][1],
        readable[len(readable) // 2][1],
        readable[-1][1],
    )

    # Proxy exposure times for the dark/median/bright frames; real EXIF
    # exposures are unknown for arbitrary uploads.
    proxy_exposure_times = (0.25, 1.0, 4.0)
    tensors = []
    for bgr, exposure in zip(picks, proxy_exposure_times):
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        linearised = np.power(rgb, GAMMA) / exposure
        stacked = np.concatenate((rgb, linearised), axis=2)  # (H, W, 6)
        tensors.append(torch.from_numpy(stacked.transpose(2, 0, 1)).float())

    return tensors[0].unsqueeze(0), tensors[1].unsqueeze(0), tensors[2].unsqueeze(0)
|
| 58 |
+
|
| 59 |
+
def postprocess_output(tensor):
    """Convert the model's (1, 3, H, W) output tensor into a displayable PIL image."""
    clamped = torch.clamp(tensor, 0, 1)  # guard against tiny numeric overshoot
    return to_pil_image(clamped.squeeze(0).cpu())
|
| 63 |
+
|
| 64 |
+
# --- 3. Load The Model (Done Once at Startup) ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "./finetuning_run_01/checkpoints/best_model.pth"

# These hyper-parameters must match the ones used to train the checkpoint.
model_args = argparse.Namespace(nDenselayer=6, growthRate=32, nBlock=16, nFeat=64, nChannel=6)
model = AHDR(model_args).to(DEVICE)

print(f"Loading checkpoint from: {CHECKPOINT_PATH}")
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

# Some checkpoints wrap the weights under a 'state_dict' key, and DataParallel
# training prefixes every key with 'module.' — normalise both before loading.
state_dict = checkpoint.get('state_dict', checkpoint)
state_dict_cleaned = {key.replace('module.', ''): value for key, value in state_dict.items()}
model.load_state_dict(state_dict_cleaned)
model.eval()
print("Model loaded successfully and is in evaluation mode.")
|
| 79 |
+
|
| 80 |
+
# --- 4. Define the Gradio Inference Function ---
def generate_hdr(file_paths):
    """Gradio entry point: merge the uploaded LDR exposures into one HDR image.

    Args:
        file_paths: list of local file paths supplied by the gr.File input.

    Returns:
        A PIL image with the merged HDR result.

    Raises:
        gr.Error: for missing input or any failure during processing.
    """
    if file_paths is None or len(file_paths) < 3:
        raise gr.Error("Please upload at least 3 images.")

    print(f"Processing {len(file_paths)} uploaded images...")

    try:
        under, normal, over = load_and_preprocess_images(file_paths)
        under = under.to(DEVICE)
        normal = normal.to(DEVICE)
        over = over.to(DEVICE)

        # Inference only — no gradients needed.
        with torch.no_grad():
            prediction = model(under, normal, over)

        result = postprocess_output(prediction)
        print("Successfully generated HDR image.")
        return result
    except Exception as e:
        # Surface the failure in the UI rather than returning nothing.
        print(f"An error occurred: {e}")
        raise gr.Error(f"An error occurred during processing: {e}")
|
| 106 |
+
|
| 107 |
+
# --- 5. Create and Launch the Gradio Interface ---
title = "Attention-guided HDR (AHDRNet)"
description = """
Upload a set of multi-exposure LDR images (at least 3 are recommended: under-exposed, normally-exposed, and over-exposed) of the same scene.
The model will automatically select the three most representative images and merge them into a single, well-exposed High Dynamic Range (HDR) image.
"""

# Each subdirectory of ./examples becomes one example: a list of its files.
example_path = Path("examples")
examples = []
if example_path.exists():
    for subdir in example_path.iterdir():
        if subdir.is_dir():
            examples.append([str(p) for p in example_path.glob(f"{subdir.name}/*")])

iface = gr.Interface(
    fn=generate_hdr,
    inputs=gr.File(
        file_count="multiple",
        label="Upload LDR Images (JPG, PNG)",
        type="filepath",
    ),
    outputs=gr.Image(type="pil", label="Generated HDR Image"),
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    iface.launch()
|
finetuning_run_01/checkpoints/best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1a1c03ba47c833deeafc9bc52a46780a1a35060653d864a341ae8a7712298eb
|
| 3 |
+
size 17668173
|
model.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torch.autograd import Variable
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class make_dilation_dense(nn.Module):
    """One dilated dense layer: dilation-2 conv + ReLU, concatenated with its input."""

    def __init__(self, nChannels, growthRate, kernel_size=3):
        super(make_dilation_dense, self).__init__()
        # padding = (k-1)//2 + 1 keeps the spatial size unchanged for dilation=2.
        self.conv = nn.Conv2d(
            nChannels,
            growthRate,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2 + 1,
            bias=True,
            dilation=2,
        )

    def forward(self, x):
        new_features = F.relu(self.conv(x))
        # Dense connectivity: pass the input through alongside the new features.
        return torch.cat((x, new_features), 1)
|
| 16 |
+
|
| 17 |
+
# Dilation Residual dense block (DRDB)
class DRDB(nn.Module):
    """Stack of dilated dense layers fused by a 1x1 conv, with a residual skip."""

    def __init__(self, nChannels, nDenselayer, growthRate):
        super(DRDB, self).__init__()
        layers = []
        in_channels = nChannels
        for _ in range(nDenselayer):
            layers.append(make_dilation_dense(in_channels, growthRate))
            in_channels += growthRate  # each dense layer concatenates growthRate channels
        self.dense_layers = nn.Sequential(*layers)
        # Project the concatenated features back down to nChannels for the skip.
        self.conv_1x1 = nn.Conv2d(in_channels, nChannels, kernel_size=1, padding=0, bias=True)

    def forward(self, x):
        fused = self.conv_1x1(self.dense_layers(x))
        return fused + x  # local residual connection
|
| 33 |
+
|
| 34 |
+
# Attention Guided HDR, AHDR-Net
class AHDR(nn.Module):
    """AHDRNet: attention-guided merging of three LDR exposures into one HDR image.

    `args` must provide: nChannel (channels per exposure input), nDenselayer
    (dense layers per DRDB), nFeat (base feature width), growthRate (channels
    added per dense layer). NOTE(review): args.nBlock is accepted by callers
    but never read here — confirm it is intentionally unused.
    """

    def __init__(self, args):
        super(AHDR, self).__init__()
        nChannel = args.nChannel
        nDenselayer = args.nDenselayer
        nFeat = args.nFeat
        growthRate = args.growthRate
        self.args = args

        # F-1: shared shallow feature extractor, applied to each exposure.
        self.conv1 = nn.Conv2d(nChannel, nFeat, kernel_size=3, padding=1, bias=True)
        # F0: fuses the three attention-weighted feature maps.
        self.conv2 = nn.Conv2d(nFeat*3, nFeat, kernel_size=3, padding=1, bias=True)
        # Attention branch for exposure 1 vs the reference (exposure 2).
        self.att11 = nn.Conv2d(nFeat*2, nFeat*2, kernel_size=3, padding=1, bias=True)
        self.att12 = nn.Conv2d(nFeat*2, nFeat, kernel_size=3, padding=1, bias=True)
        # NOTE(review): attConv1/attConv3 are unused in forward() but must stay
        # so checkpoints trained with them still load with strict state dicts.
        self.attConv1 = nn.Conv2d(nFeat, nFeat, kernel_size=3, padding=1, bias=True)
        # Attention branch for exposure 3 vs the reference (exposure 2).
        self.att31 = nn.Conv2d(nFeat*2, nFeat*2, kernel_size=3, padding=1, bias=True)
        self.att32 = nn.Conv2d(nFeat*2, nFeat, kernel_size=3, padding=1, bias=True)
        self.attConv3 = nn.Conv2d(nFeat, nFeat, kernel_size=3, padding=1, bias=True)

        # DRDBs 3
        self.RDB1 = DRDB(nFeat, nDenselayer, growthRate)
        self.RDB2 = DRDB(nFeat, nDenselayer, growthRate)
        self.RDB3 = DRDB(nFeat, nDenselayer, growthRate)

        # feature fusion (GFF)
        self.GFF_1x1 = nn.Conv2d(nFeat*3, nFeat, kernel_size=1, padding=0, bias=True)
        self.GFF_3x3 = nn.Conv2d(nFeat, nFeat, kernel_size=3, padding=1, bias=True)
        # fusion
        self.conv_up = nn.Conv2d(nFeat, nFeat, kernel_size=3, padding=1, bias=True)

        # Final projection to a 3-channel image.
        self.conv3 = nn.Conv2d(nFeat, 3, kernel_size=3, padding=1, bias=True)
        self.relu = nn.LeakyReLU()

    def forward(self, x1, x2, x3):
        """Merge three (B, nChannel, H, W) exposure tensors; x2 is the reference."""
        # Shallow features per exposure (shared conv1 weights).
        F1_ = self.relu(self.conv1(x1))
        F2_ = self.relu(self.conv1(x2))
        F3_ = self.relu(self.conv1(x3))

        # Attention map for exposure 1, conditioned on the reference features.
        F1_i = torch.cat((F1_, F2_), 1)
        F1_A = self.relu(self.att11(F1_i))
        F1_A = self.att12(F1_A)
        # FIX: nn.functional.sigmoid is deprecated; torch.sigmoid is identical.
        F1_A = torch.sigmoid(F1_A)
        F1_ = F1_ * F1_A

        # Attention map for exposure 3, conditioned on the reference features.
        F3_i = torch.cat((F3_, F2_), 1)
        F3_A = self.relu(self.att31(F3_i))
        F3_A = self.att32(F3_A)
        F3_A = torch.sigmoid(F3_A)
        F3_ = F3_ * F3_A

        F_ = torch.cat((F1_, F2_, F3_), 1)

        F_0 = self.conv2(F_)
        F_1 = self.RDB1(F_0)
        F_2 = self.RDB2(F_1)
        F_3 = self.RDB3(F_2)
        # Global feature fusion across all three DRDB outputs.
        FF = torch.cat((F_1, F_2, F_3), 1)
        FdLF = self.GFF_1x1(FF)
        FGF = self.GFF_3x3(FdLF)
        # Global residual from the reference exposure's shallow features.
        FDF = FGF + F2_
        us = self.conv_up(FDF)

        output = self.conv3(us)
        output = torch.sigmoid(output)  # map to [0, 1] image range

        return output
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
gradio
|
| 4 |
+
opencv-python-headless
|
| 5 |
+
numpy
|
utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def mk_dir(dir_path):
    """Create *dir_path* (including parents) if it does not already exist.

    Uses exist_ok=True instead of an os.path.exists() guard, which avoids the
    check-then-create race when two processes create the directory at once.
    """
    os.makedirs(dir_path, exist_ok=True)
|
| 7 |
+
|
| 8 |
+
def model_load(model, trained_model_dir, model_file_name, map_location=None):
    """Load a saved state dict into *model* and return it.

    Args:
        model: the nn.Module whose parameters should be populated.
        trained_model_dir: directory containing the checkpoint file.
        model_file_name: checkpoint file name (e.g. 'modelParas.pkl').
        map_location: optional device remapping forwarded to torch.load so a
            GPU-trained checkpoint can be loaded on a CPU-only machine
            (default None preserves the previous behaviour).

    Returns:
        The same *model* instance with the loaded weights.
    """
    model_path = os.path.join(trained_model_dir, model_file_name)
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    return model
|