Spaces:
Running
Running
Pushpesh
commited on
Commit
·
a9a2d42
1
Parent(s):
1037c89
Base
Browse files- app/__init__.py +0 -0
- app/app.py +28 -0
- app/utils.py +61 -0
- main.py +30 -0
- model/model_epoch_49.pth +3 -0
app/__init__.py
ADDED
|
File without changes
|
app/app.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np
|
| 5 |
+
from .utils import recover_light_sources
|
| 6 |
+
|
| 7 |
+
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 8 |
+
model=torch.load('model/model_epoch_49.pth',map_location=device)
|
| 9 |
+
|
| 10 |
+
def evaluate(model,image):
|
| 11 |
+
model.eval()
|
| 12 |
+
with torch.no_grad():
|
| 13 |
+
image = image.to(device)
|
| 14 |
+
#outputs= model(image.unsqueeze(0))
|
| 15 |
+
outputs= model(image)
|
| 16 |
+
return outputs.squeeze(0).squeeze(0).cpu()
|
| 17 |
+
|
| 18 |
+
def predict(input_image):
|
| 19 |
+
#input_image=Image.open(inp_img).convert('RGB')
|
| 20 |
+
input_image=input_image.resize((512,512))
|
| 21 |
+
input_image_torch=torch.tensor(np.array(input_image)).permute(2,0,1).unsqueeze(0).float()/255.0
|
| 22 |
+
mask=evaluate(model,input_image_torch)
|
| 23 |
+
mask=mask.permute(1,2,0).numpy()
|
| 24 |
+
final_img=recover_light_sources(mask,input_image)
|
| 25 |
+
return final_img
|
| 26 |
+
|
| 27 |
+
demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"),outputs=gr.Image(), examples=["test_imgs/test1.png", "test_imgs/test2.png","test_imgs/test3.png"])
|
| 28 |
+
demo.launch()
|
app/utils.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 6 |
+
model=torch.load('model/Unet_segformer.pth',map_location=device)
|
| 7 |
+
|
| 8 |
+
def evaluate(model,image):
|
| 9 |
+
model.eval()
|
| 10 |
+
with torch.no_grad():
|
| 11 |
+
image = image.to(device)
|
| 12 |
+
#outputs= model(image.unsqueeze(0))
|
| 13 |
+
outputs= model(image)
|
| 14 |
+
return outputs.squeeze(0).squeeze(0).cpu()
|
| 15 |
+
|
| 16 |
+
def predict(input_image):
|
| 17 |
+
#input_image=Image.open(inp_img).convert('RGB')
|
| 18 |
+
input_image=input_image.resize((512,512))
|
| 19 |
+
input_image_torch=torch.tensor(np.array(input_image)).permute(2,0,1).unsqueeze(0).float()/255.0
|
| 20 |
+
mask=evaluate(model,input_image_torch)
|
| 21 |
+
mask=mask.permute(1,2,0).numpy()
|
| 22 |
+
return mask
|
| 23 |
+
|
| 24 |
+
def calculate_input_illuminance(image):
|
| 25 |
+
"""
|
| 26 |
+
Calculate illuminance: I_input = C_r + C_g + C_b
|
| 27 |
+
"""
|
| 28 |
+
return np.sum(image, axis=2)
|
| 29 |
+
|
| 30 |
+
def generate_recovery_weight_matrix(illuminance_matrix, alpha=15):
|
| 31 |
+
"""
|
| 32 |
+
Generate recovery weights using power function
|
| 33 |
+
Formula: W_r = ((I_input - min) / (max - min))^α
|
| 34 |
+
"""
|
| 35 |
+
I_min = np.min(illuminance_matrix)
|
| 36 |
+
I_max = np.max(illuminance_matrix)
|
| 37 |
+
|
| 38 |
+
if I_max == I_min:
|
| 39 |
+
normalized = np.zeros_like(illuminance_matrix)
|
| 40 |
+
else:
|
| 41 |
+
normalized = (illuminance_matrix - I_min) / (I_max - I_min)
|
| 42 |
+
|
| 43 |
+
# Apply power function with α = 15
|
| 44 |
+
W_r = np.power(normalized, alpha)
|
| 45 |
+
return W_r
|
| 46 |
+
|
| 47 |
+
def recover_light_sources(original_image, network_output, alpha=15):
|
| 48 |
+
"""
|
| 49 |
+
Final recovery: I_final = (1 - W_r) ⊙ N(C) + W_r ⊙ C
|
| 50 |
+
"""
|
| 51 |
+
# Calculate illuminance and recovery weights
|
| 52 |
+
I_input = calculate_input_illuminance(original_image)
|
| 53 |
+
W_r = generate_recovery_weight_matrix(I_input, alpha)
|
| 54 |
+
|
| 55 |
+
# Expand to match image dimensions
|
| 56 |
+
W_r_expanded = np.expand_dims(W_r, axis=2)
|
| 57 |
+
W_r_expanded = np.repeat(W_r_expanded, 3, axis=2)
|
| 58 |
+
|
| 59 |
+
# Convex combination for light source recovery
|
| 60 |
+
I_final = (1 - W_r_expanded) * network_output + W_r_expanded * original_image
|
| 61 |
+
return np.clip(I_final, 0, 255).astype(np.uint8)
|
main.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, File, UploadFile
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from utils import predict
|
| 4 |
+
from starlette.responses import StreamingResponse, Response
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from PIL import Image
|
| 7 |
+
app=FastAPI()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@app.get("/")
|
| 11 |
+
def home():
|
| 12 |
+
return {"health_check":"OK"}
|
| 13 |
+
|
| 14 |
+
#@app.post("/remove")
|
| 15 |
+
#async def remove(file: UploadFile = File(...)):
|
| 16 |
+
# contents=await file.read()
|
| 17 |
+
# image = Image.open(BytesIO(contents)).convert("RGB")
|
| 18 |
+
# mask=predict(image)
|
| 19 |
+
# return {"mask": mask.tolist()}
|
| 20 |
+
|
| 21 |
+
@app.post("/remove")
|
| 22 |
+
async def remove(file: UploadFile = File(...)):
|
| 23 |
+
contents=await file.read()
|
| 24 |
+
image = Image.open(BytesIO(contents)).convert("RGB")
|
| 25 |
+
mask=predict(image)
|
| 26 |
+
img=Image.fromarray((mask*255).astype('uint8'))
|
| 27 |
+
buffered=BytesIO()
|
| 28 |
+
img.save(buffered,format="JPEG")
|
| 29 |
+
img_str=buffered.getvalue()
|
| 30 |
+
return Response(content=img_str, media_type="image/jpeg")
|
model/model_epoch_49.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d03f06fc9cf54ceaf33ad68481c41c3a281137f262d356898aadafcd6ec9ae85
|
| 3 |
+
size 197469247
|