Pushpesh committed on
Commit
a9a2d42
·
1 Parent(s): 1037c89
Files changed (5) hide show
  1. app/__init__.py +0 -0
  2. app/app.py +28 -0
  3. app/utils.py +61 -0
  4. main.py +30 -0
  5. model/model_epoch_49.pth +3 -0
app/__init__.py ADDED
File without changes
app/app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
from PIL import Image
import numpy as np
from .utils import recover_light_sources

# Use the GPU when available; map_location lets the same checkpoint
# load on CPU-only hosts.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load here unpickles a full model object
# (weights_only is not set) — only ever load trusted checkpoints.
model=torch.load('model/model_epoch_49.pth',map_location=device)
9
+
10
def evaluate(model,image):
    """Run `model` on `image` with gradient tracking disabled.

    The input is moved to the module-level `device`; the output comes
    back on the CPU with the two leading (singleton) dimensions
    squeezed away.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image.to(device))
    return outputs.squeeze(0).squeeze(0).cpu()
17
+
18
def predict(input_image):
    """Run the dehazing/segmentation model on a PIL image and recover
    bright light sources in the result.

    Parameters
    ----------
    input_image : PIL.Image.Image
        RGB input; resized to the 512x512 resolution the model expects.

    Returns
    -------
    numpy.ndarray
        uint8 HxWx3 image produced by recover_light_sources.
    """
    input_image = input_image.resize((512, 512))
    # HWC uint8 -> NCHW float scaled to [0, 1]
    input_image_torch = torch.tensor(np.array(input_image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    mask = evaluate(model, input_image_torch)
    mask = mask.permute(1, 2, 0).numpy()
    # BUG FIX: recover_light_sources takes (original_image, network_output);
    # the arguments were previously passed in swapped order.
    # The network output is in [0, 1] while the original pixels are in
    # [0, 255]; rescale so the convex combination mixes like ranges.
    # TODO(review): confirm the model output range really is [0, 1].
    final_img = recover_light_sources(np.array(input_image), mask * 255.0)
    return final_img
26
+
27
# Wire the predictor into a simple Gradio UI: PIL image in, image out,
# with three bundled example inputs.
demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"),outputs=gr.Image(), examples=["test_imgs/test1.png", "test_imgs/test2.png","test_imgs/test3.png"])
demo.launch()
app/utils.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from PIL import Image
import numpy as np

# Use the GPU when available; map_location lets the same checkpoint
# load on CPU-only hosts.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this path differs from app/app.py, which loads
# 'model/model_epoch_49.pth' — the only checkpoint added in this commit.
# Confirm 'model/Unet_segformer.pth' actually exists at runtime, or this
# import fails. torch.load also unpickles arbitrary objects — load
# trusted checkpoints only.
model=torch.load('model/Unet_segformer.pth',map_location=device)
7
+
8
def evaluate(model,image):
    """Evaluate `model` on `image` without building a gradient graph.

    Moves the input to the module-level `device`, then returns the
    model output on the CPU after squeezing the first two singleton
    dimensions (batch and channel).
    """
    model.eval()
    with torch.no_grad():
        prediction = model(image.to(device))
        result = prediction.squeeze(0).squeeze(0)
    return result.cpu()
15
+
16
def predict(input_image):
    """Resize a PIL image to 512x512, run the segmentation model, and
    return the predicted mask as an HxWxC numpy array."""
    resized = input_image.resize((512, 512))
    # uint8 HWC -> float NCHW scaled to [0, 1]
    batch = torch.tensor(np.array(resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    prediction = evaluate(model, batch)
    return prediction.permute(1, 2, 0).numpy()
23
+
24
def calculate_input_illuminance(image):
    """
    Per-pixel illuminance of an HxWx3 image: I_input = C_r + C_g + C_b.
    """
    # np.sum (rather than image.sum) also accepts array-likes such as
    # PIL images.
    illuminance = np.sum(image, axis=2)
    return illuminance
29
+
30
def generate_recovery_weight_matrix(illuminance_matrix, alpha=15):
    """
    Map illuminance to recovery weights via min-max normalisation
    followed by a power curve: W_r = ((I - min) / (max - min))^alpha.

    A flat illuminance map (max == min) yields all-zero weights. The
    large default exponent keeps weights near zero everywhere except
    the very brightest pixels.
    """
    lo = np.min(illuminance_matrix)
    hi = np.max(illuminance_matrix)
    if hi == lo:
        normalized = np.zeros_like(illuminance_matrix)
    else:
        normalized = (illuminance_matrix - lo) / (hi - lo)
    return np.power(normalized, alpha)
46
+
47
def recover_light_sources(original_image, network_output, alpha=15):
    """
    Recover bright light sources by convex-combining the network output
    with the original image: I_final = (1 - W_r) ⊙ N(C) + W_r ⊙ C.

    Pixels with high illuminance (weight near 1) keep the original
    content; everything else comes from the network output. The result
    is clipped to [0, 255] and returned as uint8.
    """
    illuminance = calculate_input_illuminance(original_image)
    weights = generate_recovery_weight_matrix(illuminance, alpha)

    # Broadcast the HxW weight map across the three colour channels.
    weights_3c = np.repeat(np.expand_dims(weights, axis=2), 3, axis=2)

    blended = (1 - weights_3c) * network_output + weights_3c * original_image
    return np.clip(blended, 0, 255).astype(np.uint8)
main.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
# NOTE(review): `utils` resolves to app/utils.py only when the server is
# launched from inside the app/ directory — confirm the intended run dir.
from utils import predict
from starlette.responses import StreamingResponse, Response
from io import BytesIO
from PIL import Image
# FastAPI application instance; the route handlers below attach to it.
app=FastAPI()
8
+
9
+
10
@app.get("/")
def home():
    """Liveness probe: report that the service is up."""
    return {"health_check": "OK"}
13
+
14
+ #@app.post("/remove")
15
+ #async def remove(file: UploadFile = File(...)):
16
+ # contents=await file.read()
17
+ # image = Image.open(BytesIO(contents)).convert("RGB")
18
+ # mask=predict(image)
19
+ # return {"mask": mask.tolist()}
20
+
21
@app.post("/remove")
async def remove(file: UploadFile = File(...)):
    """Run the segmentation model on an uploaded image and return the
    predicted mask encoded as a JPEG response."""
    payload = await file.read()
    image = Image.open(BytesIO(payload)).convert("RGB")
    mask = predict(image)
    # NOTE(review): assumes mask values lie in [0, 1] — confirm against
    # the model's output range before trusting the *255 scaling.
    mask_img = Image.fromarray((mask * 255).astype('uint8'))
    buffered = BytesIO()
    mask_img.save(buffered, format="JPEG")
    return Response(content=buffered.getvalue(), media_type="image/jpeg")
model/model_epoch_49.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d03f06fc9cf54ceaf33ad68481c41c3a281137f262d356898aadafcd6ec9ae85
3
+ size 197469247