Pushpesh commited on
Commit
fdc7f23
·
1 Parent(s): d43a4a2

Final fixes

Browse files
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.13
app.py CHANGED
@@ -4,55 +4,44 @@ from PIL import Image
4
  import numpy as np
5
  from app.utils import recover_light_sources
6
  from model.model import model
 
 
7
 
8
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
  chk=torch.load('model/model_epoch_49.pth',map_location=device)
10
  model.load_state_dict(chk['model_state_dict'])
11
 
12
- #def evaluate(model,image):
13
- # model.eval()
14
- # with torch.no_grad():
15
- # image = image.to(device)
16
- # #outputs= model(image.unsqueeze(0))
17
- # outputs= model(image)
18
- # return outputs.squeeze(0).squeeze(0).cpu()
19
-
20
- #def predict(input_image):
21
- #input_image=Image.open(inp_img).convert('RGB')
22
- # input_image=input_image.resize((512,512))
23
- # input_image_torch=torch.tensor(np.array(input_image)).permute(2,0,1).unsqueeze(0).float()/255.0
24
- # mask=evaluate(model,input_image_torch)
25
- # mask=mask.permute(1,2,0).numpy()
26
- # final_img=recover_light_sources(input_image,mask)
27
- # return final_img
28
-
29
- def evaluate(model, image_tensor):
30
  """
31
  Run the model on the given image tensor and return output as numpy array (H,W,C) in [0,255].
32
  """
 
33
  with torch.no_grad():
34
- image_tensor = image_tensor.to(device, dtype=torch.float32)
35
- outputs = model(image_tensor)
36
  outputs = torch.clamp(outputs, 0.0, 1.0)
37
  outputs_np = outputs.squeeze(0).permute(1, 2, 0).cpu().numpy()
38
- return (outputs_np * 255).astype(np.uint8) # Convert to uint8
39
 
40
  def predict(input_image):
41
  """
42
  Predict clean image from flare image, then recover light sources.
43
  """
44
- # Ensure input image is in RGB
45
- if input_image.mode != "RGB":
46
- input_image = input_image.convert("RGB")
47
  # Resize and prepare input tensor
48
- input_image_resized = input_image.resize((512, 512))
49
- input_np = np.array(input_image_resized) / 255.0 # scale to [0, 1]
50
- input_tensor = torch.from_numpy(input_np).permute(2, 0, 1).unsqueeze(0).float()
51
- # Get predicted clean image from model
52
- pred_clean_img = evaluate(model, input_tensor)
53
- # Return as PIL Image
54
- return Image.fromarray(pred_clean_img)
55
 
 
 
 
 
 
 
 
56
 
57
  demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"),outputs=gr.Image(), examples=["test_imgs/test1.png", "test_imgs/test2.png","test_imgs/test3.png"])
58
  demo.launch()
 
4
  import numpy as np
5
  from app.utils import recover_light_sources
6
  from model.model import model
7
+ from torchvision import transforms
8
+
9
 
10
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  chk=torch.load('model/model_epoch_49.pth',map_location=device)
12
  model.load_state_dict(chk['model_state_dict'])
13
 
14
+ transform = transforms.Compose([
15
+ transforms.Resize((512, 512)),
16
+ transforms.ToTensor(), # -> (C,H,W), dtype=float32, range [0,1]
17
+ ])
18
+
19
+ def evaluate(model, image):
 
 
 
 
 
 
 
 
 
 
 
 
20
  """
21
  Run the model on the given image tensor and return output as numpy array (H,W,C) in [0,255].
22
  """
23
+ model.eval()
24
  with torch.no_grad():
25
+ #image = image.to(device, dtype=torch.float32)
26
+ outputs = model(image)
27
  outputs = torch.clamp(outputs, 0.0, 1.0)
28
  outputs_np = outputs.squeeze(0).permute(1, 2, 0).cpu().numpy()
29
+ return (outputs_np * 255).astype(np.uint8) # Convert to uint8 for recovery step
30
 
31
  def predict(input_image):
32
  """
33
  Predict clean image from flare image, then recover light sources.
34
  """
 
 
 
35
  # Resize and prepare input tensor
36
+ input_img = input_image.convert('RGB').resize((512, 512), Image.BILINEAR)
 
 
 
 
 
 
37
 
38
+ input_tensor = transform(input_img).unsqueeze(0).to(device, dtype=torch.float32)
39
+
40
+ # Get predicted clean image from model
41
+ pred_clean_img = evaluate(model, input_tensor) # uint8 predicted clean
42
+ # Recover light sources
43
+ final_img = recover_light_sources(network_output=pred_clean_img,original_image=input_img)
44
+ return final_img
45
 
46
  demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"),outputs=gr.Image(), examples=["test_imgs/test1.png", "test_imgs/test2.png","test_imgs/test3.png"])
47
  demo.launch()
app/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (165 Bytes). View file
 
app/__pycache__/utils.cpython-313.pyc ADDED
Binary file (3.34 kB). View file
 
main.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ def main():
2
+ print("Hello from flare-removal-2-0!")
3
+
4
+
5
+ if __name__ == "__main__":
6
+ main()
model/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (167 Bytes). View file
 
model/__pycache__/model.cpython-313.pyc ADDED
Binary file (383 Bytes). View file
 
model/model_epoch_49.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d03f06fc9cf54ceaf33ad68481c41c3a281137f262d356898aadafcd6ec9ae85
3
- size 197469247
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e1d2df83637ae40a4569e93f8135dcd248d49595d9cba685e16aa7f3aec798e
3
+ size 197467007
pyproject.toml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "flare-removal-2-0"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "fastapi>=0.116.1",
9
+ "gradio>=5.44.1",
10
+ "numpy>=2.3.2",
11
+ "pillow>=11.3.0",
12
+ "pydantic>=2.11.7",
13
+ "segmentation-models-pytorch>=0.5.0",
14
+ "torch>=2.8.0",
15
+ "uvicorn>=0.35.0",
16
+ ]
test_imgs/test1.png CHANGED

Git LFS Details

  • SHA256: 6b273dc4e082b2350aa744ba45ddc290a1f64018477068abfdefa7e0be2c02ef
  • Pointer size: 131 Bytes
  • Size of remote file: 432 kB

Git LFS Details

  • SHA256: 952f22cccb09e622884951c753fa808dc59b89902f1c2fa92ad3f151218cc2ff
  • Pointer size: 131 Bytes
  • Size of remote file: 376 kB
test_imgs/test2.png CHANGED

Git LFS Details

  • SHA256: 3b1d4205c8b2121fee0f59da948af28bf1810a215bd8f507b38a46a88fafa844
  • Pointer size: 131 Bytes
  • Size of remote file: 537 kB

Git LFS Details

  • SHA256: 1feaa374a1795197a311645590870e41faa0e1095d70696b95f3c35651f05908
  • Pointer size: 131 Bytes
  • Size of remote file: 438 kB
test_imgs/test3.png CHANGED

Git LFS Details

  • SHA256: d6e1d50a78728037d4157e9d9bbad36fe800dfcd2b881f00e8cca84c689ec97f
  • Pointer size: 131 Bytes
  • Size of remote file: 383 kB

Git LFS Details

  • SHA256: 3a5352207df38e0153c1c02266ca6c24f3fa1e0fc1fec49175dc1c465c4be608
  • Pointer size: 131 Bytes
  • Size of remote file: 330 kB
uv.lock ADDED
The diff for this file is too large to render. See raw diff