dkescape commited on
Commit
ff3bcec
·
verified ·
1 Parent(s): d3dba73

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ import logging
4
+
5
+ # ----------------------------
6
+ # 1. Warning & logging setup
7
+ # ----------------------------
8
+ # Suppress FutureWarning from timm internals
9
+ warnings.filterwarnings(
10
+ "ignore",
11
+ category=FutureWarning,
12
+ module="timm.models.layers"
13
+ )
14
+ # Suppress UserWarning from modelscope (e.g. missing preprocessor config)
15
+ warnings.filterwarnings(
16
+ "ignore",
17
+ category=UserWarning,
18
+ module="modelscope"
19
+ )
20
+ # Only show ERROR+ logs from modelscope
21
+ logging.getLogger("modelscope").setLevel(logging.ERROR)
22
+
23
+ # ----------------------------
24
+ # 2. Standard imports
25
+ # ----------------------------
26
+ import cv2
27
+ import tempfile
28
+ import gradio as gr
29
+ import numpy as np
30
+ from PIL import Image, ImageEnhance, ImageFilter
31
+ from modelscope.outputs import OutputKeys
32
+ from modelscope.pipelines import pipeline
33
+ from modelscope.utils.constant import Tasks
34
+
35
+ # ----------------------------
36
+ # 3. Load your colorization model
37
+ # ----------------------------
38
+ img_colorization = pipeline(
39
+ Tasks.image_colorization,
40
+ model="iic/cv_ddcolor_image-colorization",
41
+ model_revision="v1.02", # explicitly specify revision
42
+ )
43
+
44
+ # ----------------------------
45
+ # 4. Image processing fns
46
+ # ----------------------------
47
+ def colorize_image(img_path: str) -> str:
48
+ image = cv2.imread(str(img_path))
49
+ output = img_colorization(image[..., ::-1])
50
+ result = output[OutputKeys.OUTPUT_IMG].astype(np.uint8)
51
+
52
+ temp_dir = tempfile.mkdtemp()
53
+ out_path = os.path.join(temp_dir, "colorized.png")
54
+ cv2.imwrite(out_path, result)
55
+ return out_path
56
+
57
+ def enhance_image(
58
+ img_path: str,
59
+ brightness: float = 1.0,
60
+ contrast: float = 1.0,
61
+ edge_enhance: bool = False
62
+ ) -> str:
63
+ image = Image.open(img_path)
64
+ image = ImageEnhance.Brightness(image).enhance(brightness)
65
+ image = ImageEnhance.Contrast(image).enhance(contrast)
66
+ if edge_enhance:
67
+ image = image.filter(ImageFilter.EDGE_ENHANCE)
68
+
69
+ temp_dir = tempfile.mkdtemp()
70
+ enhanced_path = os.path.join(temp_dir, "enhanced.png")
71
+ image.save(enhanced_path)
72
+ return enhanced_path
73
+
74
+ def process_image(
75
+ img_path: str,
76
+ brightness: float,
77
+ contrast: float,
78
+ edge_enhance: bool,
79
+ output_format: str
80
+ ):
81
+ # Colorize → Enhance → Re‑save in chosen format
82
+ colorized_path = colorize_image(img_path)
83
+ enhanced_path = enhance_image(colorized_path, brightness, contrast, edge_enhance)
84
+
85
+ img = Image.open(enhanced_path)
86
+ temp_dir = tempfile.mkdtemp()
87
+ filename = f"colorized_image.{output_format.lower()}"
88
+ output_path = os.path.join(temp_dir, filename)
89
+ img.save(output_path, format=output_format.upper())
90
+
91
+ # Return side-by-side gallery and downloadable file
92
+ return ([img_path, enhanced_path], output_path)
93
+
94
+ # ----------------------------
95
+ # 5. Gradio UI + custom CSS
96
+ # ----------------------------
97
+ custom_css = """
98
+ body { background-color: #f0f2f5; }
99
+ .gradio-container { max-width: 900px !important; margin: auto !important; }
100
+ #header { background-color: #4CAF50; padding: 20px; border-radius: 8px;
101
+ text-align: center; margin-bottom: 20px; }
102
+ #header h2, #header p { color: white; margin: 0; }
103
+ #header p { margin-top: 5px; font-size: 1rem; }
104
+ #control-panel { background: white; padding: 20px; border-radius: 8px;
105
+ box-shadow: 0 2px 8px rgba(0,0,0,0.1); margin-bottom: 20px; }
106
+ #submit-btn { background-color: #4CAF50 !important; color: white !important;
107
+ border-radius: 8px !important; font-weight: bold;
108
+ padding: 10px 20px !important; margin-top: 10px !important; }
109
+ #control-panel .gr-row { gap: 15px; }
110
+ .gr-slider, .gr-checkbox, .gr-dropdown { margin-top: 10px; }
111
+ #comparison_gallery { background: white; padding: 10px;
112
+ border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
113
+ #download-btn { margin-top: 15px !important; }
114
+ """
115
+
116
+ TITLE = "🌈 Color Restorization Model"
117
+ DESCRIPTION = "Bring your old black & white photos back to life—upload, adjust, and download in vivid color."
118
+
119
+ with gr.Blocks(title=TITLE, css=custom_css) as app:
120
+ # Header
121
+ gr.HTML(
122
+ """
123
+ <div id="header">
124
+ <h2>🌈 Color Restorization Model</h2>
125
+ <p>Bring your old black & white photos back to life—upload, adjust, and download in vivid color.</p>
126
+ </div>
127
+ """
128
+ )
129
+
130
+ # Controls & results
131
+ with gr.Column(elem_id="control-panel"):
132
+ with gr.Row():
133
+ # Inputs
134
+ with gr.Column():
135
+ input_image = gr.Image(type="filepath", label="Upload B&W Image", interactive=True)
136
+ brightness_slider = gr.Slider(0.5, 2.0, value=1.0, label="Brightness")
137
+ contrast_slider = gr.Slider(0.5, 2.0, value=1.0, label="Contrast")
138
+ edge_enhance_checkbox = gr.Checkbox(label="Apply Edge Enhancement")
139
+ output_format_dropdown = gr.Dropdown(["PNG", "JPEG", "TIFF"], value="PNG", label="Output Format")
140
+ submit_btn = gr.Button("Colorize", elem_id="submit-btn")
141
+
142
+ # Outputs
143
+ with gr.Column():
144
+ comparison_gallery = gr.Gallery(
145
+ label="Original vs. Colorized",
146
+ columns=2,
147
+ elem_id="comparison_gallery",
148
+ height="auto"
149
+ )
150
+ download_btn = gr.File(label="Download Colorized Image", elem_id="download-btn")
151
+
152
+ # Wire up
153
+ submit_btn.click(
154
+ fn=process_image,
155
+ inputs=[
156
+ input_image,
157
+ brightness_slider,
158
+ contrast_slider,
159
+ edge_enhance_checkbox,
160
+ output_format_dropdown
161
+ ],
162
+ outputs=[comparison_gallery, download_btn]
163
+ )
164
+
165
+ # Launch
166
+ if __name__ == "__main__":
167
+ port = int(os.environ.get("PORT", 7860))
168
+ app.queue().launch(server_name="0.0.0.0", server_port=port)