logan-codes commited on
Commit
7af3f84
ยท
1 Parent(s): ab81a26

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +90 -16
main.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  from pathlib import Path
3
 
4
  import gradio as gr
@@ -7,30 +9,58 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
7
  from PIL import Image
8
  from realesrgan import RealESRGANer
9
 
 
 
 
10
  MODEL_FILENAME = "RealESRGAN_x4plus.pth"
11
  MODEL_SCALE = 4
12
  SUPPORTED_SCALES = (2, 4)
13
 
14
 
 
 
 
15
  def resolve_model_path() -> str:
16
  project_root = Path(__file__).resolve().parent
17
- candidate_paths = [project_root / "weights" / MODEL_FILENAME]
 
18
 
19
- home_dir = os.environ.get("HOME")
20
- if home_dir:
21
- candidate_paths.append(Path(home_dir) / "weights" / MODEL_FILENAME)
22
 
23
- for candidate in candidate_paths:
24
- if candidate.exists():
25
- return str(candidate)
26
 
27
- checked_locations = ", ".join(str(path) for path in candidate_paths)
28
- raise FileNotFoundError(
29
- f"Could not find {MODEL_FILENAME}. Checked: {checked_locations}"
30
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
 
32
 
 
 
 
 
 
 
33
  def load_model():
 
 
34
  model = RRDBNet(
35
  num_in_ch=3,
36
  num_out_ch=3,
@@ -39,36 +69,70 @@ def load_model():
39
  num_grow_ch=32,
40
  scale=MODEL_SCALE,
41
  )
42
- return RealESRGANer(
 
 
 
 
 
 
43
  scale=MODEL_SCALE,
44
- model_path=resolve_model_path(),
45
  model=model,
46
- half=False,
47
  )
48
 
 
 
 
 
49
 
 
 
 
50
  upsampler_cache = None
51
 
52
 
53
  def get_upsampler():
54
  global upsampler_cache
55
  if upsampler_cache is None:
 
56
  upsampler_cache = load_model()
 
 
 
57
  return upsampler_cache
58
 
59
 
 
 
 
60
  def upscale(image: Image.Image, scale: int):
 
 
61
  if image is None:
 
62
  raise gr.Error("Please upload an image first.")
63
 
64
  if scale not in SUPPORTED_SCALES:
 
65
  raise gr.Error(f"Unsupported upscale factor: {scale}")
66
 
 
 
67
  upsampler = get_upsampler()
 
 
68
  output, _ = upsampler.enhance(np.array(image), outscale=scale)
 
 
 
69
  return Image.fromarray(output)
70
 
71
 
 
 
 
72
  def build_demo():
73
  with gr.Blocks(title="AI Image Upscaler") as app:
74
  gr.Markdown("## ๐Ÿ” AI Image Upscaler\nPowered by Real-ESRGAN")
@@ -77,7 +141,9 @@ def build_demo():
77
  with gr.Column():
78
  input_img = gr.Image(type="pil", label="Input Image")
79
  scale_choice = gr.Radio(
80
- choices=list(SUPPORTED_SCALES), value=4, label="Upscale Factor"
 
 
81
  )
82
  btn = gr.Button("Upscale", variant="primary")
83
 
@@ -85,10 +151,18 @@ def build_demo():
85
  output_img = gr.Image(type="pil", label="Upscaled Output")
86
 
87
  btn.click(fn=upscale, inputs=[input_img, scale_choice], outputs=output_img)
 
88
  return app
89
 
90
 
 
 
 
91
  demo = build_demo()
92
 
93
  if __name__ == "__main__":
94
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
1
  import os
2
+ import sys
3
+ import requests
4
  from pathlib import Path
5
 
6
  import gradio as gr
 
9
  from PIL import Image
10
  from realesrgan import RealESRGANer
11
 
12
+ # Flush logs immediately (important for HF Spaces)
13
+ sys.stdout.reconfigure(line_buffering=True)
14
+
15
  MODEL_FILENAME = "RealESRGAN_x4plus.pth"
16
  MODEL_SCALE = 4
17
  SUPPORTED_SCALES = (2, 4)
18
 
19
 
20
+ # -------------------------------
21
+ # Model Path + Download Handling
22
+ # -------------------------------
23
  def resolve_model_path() -> str:
24
  project_root = Path(__file__).resolve().parent
25
+ model_dir = project_root / "weights"
26
+ model_dir.mkdir(exist_ok=True)
27
 
28
+ model_path = model_dir / MODEL_FILENAME
 
 
29
 
30
+ print(f"[INFO] Looking for model at: {model_path}")
 
 
31
 
32
+ if model_path.exists():
33
+ print("[SUCCESS] Model already exists. Skipping download.")
34
+ else:
35
+ print("[INFO] Model not found. Starting download...")
36
+
37
+ url = f"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/{MODEL_FILENAME}"
38
+ response = requests.get(url, stream=True)
39
+
40
+ total_size = int(response.headers.get("content-length", 0))
41
+ downloaded = 0
42
+
43
+ with open(model_path, "wb") as f:
44
+ for chunk in response.iter_content(chunk_size=8192):
45
+ if chunk:
46
+ f.write(chunk)
47
+ downloaded += len(chunk)
48
+
49
+ if total_size > 0:
50
+ percent = (downloaded / total_size) * 100
51
+ print(f"[DOWNLOAD] {percent:.2f}%")
52
 
53
+ print("[SUCCESS] Model downloaded successfully!")
54
 
55
+ return str(model_path)
56
+
57
+
58
+ # -------------------------------
59
+ # Load Model
60
+ # -------------------------------
61
  def load_model():
62
+ print("[INFO] Initializing model architecture...")
63
+
64
  model = RRDBNet(
65
  num_in_ch=3,
66
  num_out_ch=3,
 
69
  num_grow_ch=32,
70
  scale=MODEL_SCALE,
71
  )
72
+
73
+ print("[INFO] Resolving model path...")
74
+ model_path = resolve_model_path()
75
+
76
+ print(f"[INFO] Loading model weights from: {model_path}")
77
+
78
+ upsampler = RealESRGANer(
79
  scale=MODEL_SCALE,
80
+ model_path=model_path,
81
  model=model,
82
+ half=False, # set True if GPU available
83
  )
84
 
85
+ print("[SUCCESS] Model loaded successfully!")
86
+
87
+ return upsampler
88
+
89
 
90
+ # -------------------------------
91
+ # Cache Model
92
+ # -------------------------------
93
  upsampler_cache = None
94
 
95
 
96
  def get_upsampler():
97
  global upsampler_cache
98
  if upsampler_cache is None:
99
+ print("[INFO] No cached model found. Loading...")
100
  upsampler_cache = load_model()
101
+ else:
102
+ print("[INFO] Using cached model.")
103
+
104
  return upsampler_cache
105
 
106
 
107
+ # -------------------------------
108
+ # Upscale Function
109
+ # -------------------------------
110
  def upscale(image: Image.Image, scale: int):
111
+ print("[INFO] Upscale request received")
112
+
113
  if image is None:
114
+ print("[ERROR] No image provided")
115
  raise gr.Error("Please upload an image first.")
116
 
117
  if scale not in SUPPORTED_SCALES:
118
+ print(f"[ERROR] Unsupported scale: {scale}")
119
  raise gr.Error(f"Unsupported upscale factor: {scale}")
120
 
121
+ print(f"[INFO] Using scale: {scale}")
122
+
123
  upsampler = get_upsampler()
124
+
125
+ print("[INFO] Starting image enhancement...")
126
  output, _ = upsampler.enhance(np.array(image), outscale=scale)
127
+
128
+ print("[SUCCESS] Upscaling completed!")
129
+
130
  return Image.fromarray(output)
131
 
132
 
133
+ # -------------------------------
134
+ # Gradio UI
135
+ # -------------------------------
136
  def build_demo():
137
  with gr.Blocks(title="AI Image Upscaler") as app:
138
  gr.Markdown("## ๐Ÿ” AI Image Upscaler\nPowered by Real-ESRGAN")
 
141
  with gr.Column():
142
  input_img = gr.Image(type="pil", label="Input Image")
143
  scale_choice = gr.Radio(
144
+ choices=list(SUPPORTED_SCALES),
145
+ value=4,
146
+ label="Upscale Factor",
147
  )
148
  btn = gr.Button("Upscale", variant="primary")
149
 
 
151
  output_img = gr.Image(type="pil", label="Upscaled Output")
152
 
153
  btn.click(fn=upscale, inputs=[input_img, scale_choice], outputs=output_img)
154
+
155
  return app
156
 
157
 
158
+ # -------------------------------
159
+ # Run App
160
+ # -------------------------------
161
  demo = build_demo()
162
 
163
  if __name__ == "__main__":
164
+ demo.launch(
165
+ server_name="0.0.0.0",
166
+ server_port=7860,
167
+ queue=True # prevents crashes under load
168
+ )