Abubakar125 committed on
Commit
6852b64
·
0 Parent(s):

Initial commit of AHDRNet Gradio app

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import os
6
+ from pathlib import Path
7
+ from torchvision.transforms.functional import to_pil_image
8
+ import argparse # Needed for the model initialization
9
+
10
+ # --- 1. Import Your Model Architecture ---
11
+ # This assumes your model.py file is in the same directory.
12
+ from model import AHDR
13
+
14
+ # --- 2. Define Constants and Helper Functions ---
15
+ GAMMA = 2.2
16
+
17
def load_and_preprocess_images(file_paths):
    """
    Load images from temporary Gradio file paths, rank them by mean
    brightness, pick the darkest / middle / brightest frames, and turn each
    into the 6-channel input the AHDR model expects.

    Returns three (1, 6, H, W) float tensors: under-, normally- and
    over-exposed, in that order.
    """
    if len(file_paths) < 3:
        raise gr.Error("Please upload at least 3 images for the best results. The model uses under, normal, and over-exposed shots.")

    # Decode every readable upload and record its average gray level.
    readable = []
    for path in file_paths:
        bgr = cv2.imread(path)
        if bgr is None:
            continue
        mean_level = float(np.mean(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)))
        readable.append((mean_level, bgr))

    if len(readable) < 3:
        raise gr.Error(f"Could only read {len(readable)} valid images, but need at least 3.")

    # Darkest, middle and brightest frames stand in for the exposure bracket.
    readable.sort(key=lambda item: item[0])
    picks = [readable[0][1], readable[len(readable) // 2][1], readable[-1][1]]

    # Proxy exposure times for the under / normal / over-exposed frames.
    proxy_exposure_times = [0.25, 1.0, 4.0]
    tensors = []
    for exposure, bgr in zip(proxy_exposure_times, picks):
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        # Gamma-expand the LDR frame and normalise by its proxy exposure.
        linearised = np.power(rgb, GAMMA) / exposure
        stacked = np.concatenate((rgb, linearised), axis=2)
        tensors.append(torch.from_numpy(stacked.transpose(2, 0, 1)).float())

    return tensors[0].unsqueeze(0), tensors[1].unsqueeze(0), tensors[2].unsqueeze(0)
58
+
59
def postprocess_output(tensor):
    """Clamp the model's output tensor to [0, 1] and convert it to a PIL image."""
    clamped = tensor.clamp(0, 1).squeeze(0).cpu()
    return to_pil_image(clamped)
63
+
64
# --- 3. Load The Model (Done Once at Startup) ---
# Prefer GPU when available; all inputs are moved to this device in generate_hdr.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fine-tuned checkpoint shipped with the repo (Git-LFS tracked).
CHECKPOINT_PATH = "./finetuning_run_01/checkpoints/best_model.pth"

# Architecture hyper-parameters; must match the checkpoint's shapes.
# nBlock is included in the Namespace but is never read by AHDR.__init__.
model_args = argparse.Namespace(nDenselayer=6, growthRate=32, nBlock=16, nFeat=64, nChannel=6)
model = AHDR(model_args).to(DEVICE)

print(f"Loading checkpoint from: {CHECKPOINT_PATH}")
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects; acceptable only because the checkpoint ships with this repo.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

# Accept both {'state_dict': ...}-wrapped checkpoints and bare state dicts.
state_dict = checkpoint.get('state_dict', checkpoint)
# Strip the 'module.' prefix typically added by nn.DataParallel training.
state_dict_cleaned = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict_cleaned)
model.eval()
print("Model loaded successfully and is in evaluation mode.")
79
+
80
+ # --- 4. Define the Gradio Inference Function ---
81
def generate_hdr(file_paths):
    """
    Gradio entry point: merge the uploaded LDR exposures into one HDR image.

    Parameters
    ----------
    file_paths : list[str] | None
        Temporary file paths supplied by the gr.File component.

    Returns
    -------
    PIL.Image.Image
        The predicted HDR image.

    Raises
    ------
    gr.Error
        If fewer than 3 images are supplied or processing fails.
    """
    if file_paths is None or len(file_paths) < 3:
        raise gr.Error("Please upload at least 3 images.")

    print(f"Processing {len(file_paths)} uploaded images...")

    try:
        i1, i2, i3 = load_and_preprocess_images(file_paths)
        i1, i2, i3 = i1.to(DEVICE), i2.to(DEVICE), i3.to(DEVICE)

        # Inference only — no gradients needed.
        with torch.no_grad():
            prediction_tensor = model(i1, i2, i3)

        output_image = postprocess_output(prediction_tensor)

        print("Successfully generated HDR image.")
        return output_image

    except gr.Error:
        # FIX: user-facing errors raised downstream (e.g. "need at least 3
        # valid images") were previously caught by the generic handler and
        # re-wrapped as "An error occurred during processing: ...".
        # Let them propagate to the UI unchanged.
        raise
    except Exception as e:
        print(f"An error occurred: {e}")
        raise gr.Error(f"An error occurred during processing: {e}")
106
+
107
# --- 5. Create and Launch the Gradio Interface ---
title = "Attention-guided HDR (AHDRNet)"
description = """
Upload a set of multi-exposure LDR images (at least 3 are recommended: under-exposed, normally-exposed, and over-exposed) of the same scene.
The model will automatically select the three most representative images and merge them into a single, well-exposed High Dynamic Range (HDR) image.
"""

# Each sub-directory of examples/ is one scene: its files form one example row.
example_path = Path("examples")
examples = []
if example_path.exists():
    for scene_dir in example_path.iterdir():
        if scene_dir.is_dir():
            examples.append([str(p) for p in example_path.glob(f"{scene_dir.name}/*")])

iface = gr.Interface(
    fn=generate_hdr,
    inputs=gr.File(
        file_count="multiple",
        label="Upload LDR Images (JPG, PNG)",
        type="filepath",
    ),
    outputs=gr.Image(type="pil", label="Generated HDR Image"),
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    # Default launch; add share=True only for temporary local testing.
    iface.launch()
finetuning_run_01/checkpoints/best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1a1c03ba47c833deeafc9bc52a46780a1a35060653d864a341ae8a7712298eb
3
+ size 17668173
model.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from torch.autograd import Variable
6
+
7
+
8
class make_dilation_dense(nn.Module):
    """One dilated dense-connection layer: dilated conv + ReLU, with the
    result concatenated to the input along the channel axis."""

    def __init__(self, nChannels, growthRate, kernel_size=3):
        super(make_dilation_dense, self).__init__()
        # padding = (k-1)//2 + 1 keeps spatial size unchanged at dilation 2.
        self.conv = nn.Conv2d(
            nChannels, growthRate, kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2 + 1, bias=True, dilation=2,
        )

    def forward(self, x):
        grown = F.relu(self.conv(x))
        return torch.cat((x, grown), 1)


# Dilation Residual dense block (DRDB)
class DRDB(nn.Module):
    """A stack of dilated dense layers fused back to nChannels by a 1x1
    convolution, with a residual connection around the whole block."""

    def __init__(self, nChannels, nDenselayer, growthRate):
        super(DRDB, self).__init__()
        width = nChannels
        layers = []
        for _ in range(nDenselayer):
            layers.append(make_dilation_dense(width, growthRate))
            width += growthRate
        self.dense_layers = nn.Sequential(*layers)
        # 1x1 fusion back down to the block's input width.
        self.conv_1x1 = nn.Conv2d(width, nChannels, kernel_size=1, padding=0, bias=True)

    def forward(self, x):
        fused = self.conv_1x1(self.dense_layers(x))
        return fused + x
33
+
34
# Attention Guided HDR, AHDR-Net
class AHDR(nn.Module):
    """Attention-guided HDR merging network (AHDRNet).

    forward() takes three 6-channel exposure tensors (x2 is the reference
    frame) and returns a 3-channel prediction squashed to (0, 1) by a sigmoid.

    ``args`` must provide nChannel, nDenselayer, nFeat and growthRate;
    ``args.nBlock`` may be present but is never read here.
    """

    def __init__(self, args):
        super(AHDR, self).__init__()
        nChannel = args.nChannel
        nDenselayer = args.nDenselayer
        nFeat = args.nFeat
        growthRate = args.growthRate
        self.args = args

        # F-1: shallow feature extraction, shared across the three exposures.
        self.conv1 = nn.Conv2d(nChannel, nFeat, kernel_size=3, padding=1, bias=True)
        # F0: fuses the three attention-weighted feature maps.
        self.conv2 = nn.Conv2d(nFeat*3, nFeat, kernel_size=3, padding=1, bias=True)
        # Attention sub-nets for frames 1 and 3 relative to reference frame 2.
        self.att11 = nn.Conv2d(nFeat*2, nFeat*2, kernel_size=3, padding=1, bias=True)
        self.att12 = nn.Conv2d(nFeat*2, nFeat, kernel_size=3, padding=1, bias=True)
        # NOTE(review): attConv1/attConv3 are never used in forward(); they
        # are kept so checkpoints containing these parameters still load.
        self.attConv1 = nn.Conv2d(nFeat, nFeat, kernel_size=3, padding=1, bias=True)
        self.att31 = nn.Conv2d(nFeat*2, nFeat*2, kernel_size=3, padding=1, bias=True)
        self.att32 = nn.Conv2d(nFeat*2, nFeat, kernel_size=3, padding=1, bias=True)
        self.attConv3 = nn.Conv2d(nFeat, nFeat, kernel_size=3, padding=1, bias=True)

        # Three dilated residual dense blocks (DRDBs).
        self.RDB1 = DRDB(nFeat, nDenselayer, growthRate)
        self.RDB2 = DRDB(nFeat, nDenselayer, growthRate)
        self.RDB3 = DRDB(nFeat, nDenselayer, growthRate)
        # Global feature fusion (GFF).
        self.GFF_1x1 = nn.Conv2d(nFeat*3, nFeat, kernel_size=1, padding=0, bias=True)
        self.GFF_3x3 = nn.Conv2d(nFeat, nFeat, kernel_size=3, padding=1, bias=True)
        # Final fusion conv.
        self.conv_up = nn.Conv2d(nFeat, nFeat, kernel_size=3, padding=1, bias=True)
        # Output projection to 3 channels.
        self.conv3 = nn.Conv2d(nFeat, 3, kernel_size=3, padding=1, bias=True)
        self.relu = nn.LeakyReLU()

    def forward(self, x1, x2, x3):
        # Shared shallow features; conv1 weights are reused for every frame.
        F1_ = self.relu(self.conv1(x1))
        F2_ = self.relu(self.conv1(x2))
        F3_ = self.relu(self.conv1(x3))

        # Attention map for frame 1, conditioned on the reference frame 2.
        F1_i = torch.cat((F1_, F2_), 1)
        F1_A = self.relu(self.att11(F1_i))
        F1_A = self.att12(F1_A)
        # FIX: nn.functional.sigmoid is deprecated; torch.sigmoid is the
        # numerically identical replacement.
        F1_A = torch.sigmoid(F1_A)
        F1_ = F1_ * F1_A

        # Attention map for frame 3, conditioned on the reference frame 2.
        F3_i = torch.cat((F3_, F2_), 1)
        F3_A = self.relu(self.att31(F3_i))
        F3_A = self.att32(F3_A)
        F3_A = torch.sigmoid(F3_A)
        F3_ = F3_ * F3_A

        F_ = torch.cat((F1_, F2_, F3_), 1)

        # Merging network: DRDB cascade followed by global feature fusion.
        F_0 = self.conv2(F_)
        F_1 = self.RDB1(F_0)
        F_2 = self.RDB2(F_1)
        F_3 = self.RDB3(F_2)
        FF = torch.cat((F_1, F_2, F_3), 1)
        FdLF = self.GFF_1x1(FF)
        FGF = self.GFF_3x3(FdLF)
        # Global residual from the reference frame's shallow features.
        FDF = FGF + F2_
        us = self.conv_up(FDF)

        output = self.conv3(us)
        output = torch.sigmoid(output)

        return output
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ opencv-python-headless
5
+ numpy
utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
def mk_dir(dir_path):
    """Create *dir_path* (including parents) if it does not already exist.

    FIX: the original check-then-create (os.path.exists followed by
    os.makedirs) could raise FileExistsError when two processes raced to
    create the same directory; exist_ok=True makes the call atomic and
    idempotent.
    """
    os.makedirs(dir_path, exist_ok=True)
7
+
8
def model_load(model, trained_model_dir, model_file_name, map_location=None):
    """Load saved weights from trained_model_dir/model_file_name into *model*.

    Parameters
    ----------
    model : torch.nn.Module
        Network whose parameters are overwritten in place.
    trained_model_dir : str
        Directory containing the checkpoint file.
    model_file_name : str
        Checkpoint file name (e.g. 'modelParas.pkl').
    map_location : optional
        Forwarded to torch.load. FIX: pass e.g. 'cpu' to load a GPU-saved
        checkpoint on a CPU-only machine; the default None preserves the
        original behaviour (tensors restored to their saved device).

    Returns
    -------
    torch.nn.Module
        The same *model* instance, with the loaded weights.
    """
    model_path = os.path.join(trained_model_dir, model_file_name)
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    return model