r3gm commited on
Commit
d5dba13
·
verified ·
1 Parent(s): 072260f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +388 -0
app.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import random
4
+ import os
5
+ from PIL import Image
6
+ from huggingface_hub import HfApi
7
+ from io import StringIO
8
+ from transformers import pipeline
9
+ import torch
10
+ import requests # Needed for downloading models
11
+ from tqdm import tqdm # For download progress bar
12
+ import spaces
13
+
14
+ # --- New Official Implementation Imports ---
15
+ from stablepy import load_upscaler_model
16
+
17
+ # --- New Global Constants ---
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+ DIRECTORY_UPSCALERS = "upscalers"
20
+
21
+ # --- Configuration ---
22
+ # Set your Hugging Face Write Token as an environment variable
23
+ # export HF_TOKEN_ORG="hf_YourTokenHere"
24
+ HF_TOKEN_ORG = os.getenv("HF_TOKEN_ORG")
25
+ DATASET_REPO_ID = "TestOrganizationPleaseIgnore/test"
26
+ DATASET_FILENAME = "upscaler_preferences.csv"
27
+ LOCAL_CSV_PATH = "upscaler_preferences_local.csv" # Local backup for safety
28
+ PUSH_THRESHOLD = 10 # Push after this many new votes
29
+
30
+ # This dictionary remains as a global constant as it's static configuration
31
+ UPSCALER_DICT_GUI = {
32
+ "RealESRGAN_x4plus": "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth",
33
+ "RealESRGAN_x2plus": "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth",
34
+ "SwinIR_x4": "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth",
35
+ "BSRGAN_x2": "https://huggingface.co/glassful/models/resolve/main/BSRGANx2.pth",
36
+ "NewModel_x4_beta": "path/to/new_model.pth" # Example of a local model
37
+ }
38
+
39
+ # --- Helper Functions for New Implementation ---
40
+ def download_model(directory, url):
41
+ """Downloads a file from a URL to a specified directory with a progress bar."""
42
+ if not os.path.exists(directory):
43
+ os.makedirs(directory)
44
+ print(f"Created directory: {directory}")
45
+
46
+ filename = url.split('/')[-1]
47
+ filepath = os.path.join(directory, filename)
48
+
49
+ if os.path.exists(filepath):
50
+ print(f"Model '{filename}' already exists. Skipping download.")
51
+ return filepath
52
+
53
+ try:
54
+ print(f"Downloading model '{filename}' from {url}...")
55
+ response = requests.get(url, stream=True)
56
+ response.raise_for_status() # Raise an exception for bad status codes
57
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
58
+ block_size = 1024 # 1 Kibibyte
59
+
60
+ with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Downloading {filename}") as progress_bar:
61
+ with open(filepath, 'wb') as file:
62
+ for data in response.iter_content(block_size):
63
+ progress_bar.update(len(data))
64
+ file.write(data)
65
+
66
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
67
+ print("ERROR, something went wrong during download.")
68
+ return None
69
+
70
+ print(f"Model '{filename}' downloaded successfully to '{filepath}'.")
71
+ return filepath
72
+ except requests.exceptions.RequestException as e:
73
+ print(f"Error downloading model: {e}")
74
+ return None
75
+
76
+ def extract_exif_data(image):
77
+ """Placeholder function to extract EXIF data. Can be expanded later."""
78
+ # In a real implementation, you would use a library like piexif
79
+ # and return the exif bytes. For now, it does nothing.
80
+ return None
81
+
82
+
83
+ class UpscalerApp:
84
+ def __init__(self, repo_id, filename, local_path, push_threshold):
85
+ """
86
+ Initializes the application, loads data, and sets up state.
87
+ """
88
+ self.repo_id = repo_id
89
+ self.filename = filename
90
+ self.local_path = local_path
91
+ self.push_threshold = push_threshold
92
+
93
+ self.results_df = None
94
+ self.new_votes_count = 0
95
+
96
+ # Initialize the image classifier on the correct device (GPU or CPU)
97
+ print(f"Initializing classifier on device: {DEVICE}")
98
+ self.classifier = pipeline(
99
+ "zero-shot-image-classification",
100
+ model="laion/CLIP-ViT-L-14-laion2B-s32B-b82K",
101
+ device=DEVICE
102
+ )
103
+ self.disambiguation_dict = {
104
+ "Modern photo or photorealistic CGI": "modern_photo_cgi",
105
+ "Old vintage photograph": "vintage_photo",
106
+ "Anime illustration": "anime_illustration",
107
+ "Manga": "manga",
108
+ "Cartoon, Comic book": "cartoon_comic",
109
+ "In-game screenshot with heads-up display HUD or UI elements": "in_game_screenshot_hud",
110
+ "Pixel art or low-resolution retro graphics": "pixel_art_retro",
111
+ "Text document or code": "text_document_code"
112
+ }
113
+ self.candidate_labels = list(self.disambiguation_dict.keys())
114
+
115
+ self.initialize_dataset()
116
+ self.ui = self.build_gradio_ui()
117
+
118
+ def initialize_dataset(self):
119
+ """
120
+ Loads the dataset from the Hub, falling back to a local file,
121
+ and finally creating a new one if necessary.
122
+ """
123
+ if HF_TOKEN_ORG is None:
124
+ print("WARNING: HF_TOKEN_ORG not set. Results will only be saved locally.")
125
+
126
+ # 1. Try to load from Hugging Face Hub first
127
+ try:
128
+ api = HfApi()
129
+ file_path = api.hf_hub_download(
130
+ repo_id=self.repo_id,
131
+ filename=self.filename,
132
+ repo_type="dataset",
133
+ token=HF_TOKEN_ORG
134
+ )
135
+ self.results_df = pd.read_csv(file_path).set_index("model_name")
136
+ print(f"Successfully loaded results from '{self.repo_id}'.")
137
+ except Exception as e:
138
+ print(f"Could not load from Hub (may not exist yet): {e}")
139
+ # 2. If Hub fails, try to load from local backup
140
+ if os.path.exists(self.local_path):
141
+ print(f"Loading results from local file: '{self.local_path}'")
142
+ self.results_df = pd.read_csv(self.local_path).set_index("model_name")
143
+ else:
144
+ # 3. If no local file, create a new DataFrame
145
+ print("No local CSV found. Creating a new preference count DataFrame.")
146
+ model_names = list(UPSCALER_DICT_GUI.keys())
147
+ columns = ['model_name', 'count'] + list(self.disambiguation_dict.values())
148
+ self.results_df = pd.DataFrame(columns=columns).set_index('model_name')
149
+
150
+
151
+ # Ensure all current models and columns exist in the DataFrame
152
+ for model in UPSCALER_DICT_GUI:
153
+ if model not in self.results_df.index:
154
+ print(f"Adding new model '{model}' to the DataFrame.")
155
+ self.results_df.loc[model] = 0
156
+
157
+ for col in list(self.disambiguation_dict.values()):
158
+ if col not in self.results_df.columns:
159
+ self.results_df[col] = 0
160
+
161
+ # Save a clean local copy on startup
162
+ self.save_results_to_local_csv()
163
+
164
+
165
+ def push_results_to_hub(self):
166
+ """
167
+ Pushes the current results DataFrame to the Hugging Face Hub.
168
+ This is a BLOCKING operation and will freeze the UI.
169
+ """
170
+ if HF_TOKEN_ORG is None:
171
+ print("Skipping push: HF_TOKEN_ORG not available.")
172
+ return
173
+
174
+ if self.results_df is None or self.results_df.empty:
175
+ return
176
+
177
+ print(f"Blocking UI to push results to '{self.repo_id}'...")
178
+ try:
179
+ csv_buffer = StringIO()
180
+ # reset_index() makes 'model_name' a column again before saving
181
+ self.results_df.reset_index().to_csv(csv_buffer, index=False)
182
+
183
+ api = HfApi()
184
+ api.upload_file(
185
+ path_or_fileobj=csv_buffer.getvalue().encode("utf-8"),
186
+ path_in_repo=self.filename,
187
+ repo_id=self.repo_id,
188
+ repo_type="dataset",
189
+ token=HF_TOKEN_ORG,
190
+ commit_message="Automated preference count update"
191
+ )
192
+ print("Successfully pushed updated results to the Hub.")
193
+ except Exception as e:
194
+ print(f"Error pushing results to the Hub: {e}")
195
+
196
+ def save_results_to_local_csv(self):
197
+ """Saves the current DataFrame to a local CSV file for persistence."""
198
+ if self.results_df is not None:
199
+ self.results_df.reset_index().to_csv(self.local_path, index=False)
200
+
201
+ # --- Official upscale function ---
202
+ def process_upscale(self, image, upscaler_name, upscaler_size, tile, tile_overlap, half):
203
+ """
204
+ Processes an image using the specified upscaler model and settings.
205
+ """
206
+ if image is None:
207
+ return None
208
+
209
+ print(f"Upscaling with: {upscaler_name}, Size: {upscaler_size}, Tile: {tile}, Overlap: {tile_overlap}, Half: {half}")
210
+
211
+ image = image.convert("RGB")
212
+ # exif_image = extract_exif_data(image) # Placeholder for future use
213
+
214
+ model_path = UPSCALER_DICT_GUI[upscaler_name]
215
+
216
+ # Check if the model is a URL and download it if it doesn't exist locally
217
+ if "https://" in str(model_path) or "http://" in str(model_path):
218
+ local_model_path = download_model(DIRECTORY_UPSCALERS, model_path)
219
+ if local_model_path is None:
220
+ # Handle download failure
221
+ gr.Warning("Failed to download the upscaler model. Please check the console for errors.")
222
+ return None
223
+ model_path = local_model_path
224
+
225
+ elif not os.path.exists(model_path):
226
+ gr.Warning(f"Local model file not found at: {model_path}")
227
+ return None
228
+
229
+
230
+ # Load the upscaler model with specified tile and precision settings
231
+ scaler_beta = load_upscaler_model(model=model_path, tile=tile, tile_overlap=tile_overlap, device=DEVICE, half=half)
232
+
233
+ # Perform the upscale
234
+ image_up = scaler_beta.upscale(image, upscaler_size, True)
235
+
236
+ return image_up
237
+
238
+ # --- Gradio Callback Functions ---
239
+ def blind_upscale(self, image, upscaler_size, tile, tile_overlap, half):
240
+ if image is None:
241
+ return None, None, "Please upload an image.", "", "", "", gr.Button(interactive=False), gr.Button(interactive=False)
242
+
243
+ # Classify the image
244
+ predictions = self.classifier(image, candidate_labels=self.candidate_labels)
245
+ top_prediction_label = predictions[0]['label']
246
+ top_prediction_key = self.disambiguation_dict[top_prediction_label]
247
+
248
+ model_keys = list(UPSCALER_DICT_GUI.keys())
249
+ if len(model_keys) < 2:
250
+ return None, None, "Not enough models to compare.", "", "", "", gr.Button(interactive=False), gr.Button(interactive=False)
251
+
252
+ model_a_name, model_b_name = random.sample(model_keys, 2)
253
+
254
+ # Process both images with the same settings from the UI
255
+ upscaled_a = self.process_upscale(image, model_a_name, upscaler_size, tile, tile_overlap, half)
256
+ upscaled_b = self.process_upscale(image, model_b_name, upscaler_size, tile, tile_overlap, half)
257
+
258
+ if upscaled_a is None or upscaled_b is None:
259
+ # Handle case where upscaling failed (e.g., model download error)
260
+ return None, None, "Upscaling failed. Check console for details.", "", "", "", gr.Button(interactive=False), gr.Button(interactive=False)
261
+
262
+ result_text = f"Image classified as: **{top_prediction_label}**. Which result do you prefer?"
263
+
264
+ return upscaled_a, upscaled_b, result_text, model_a_name, model_b_name, top_prediction_key, gr.Button(interactive=True), gr.Button(interactive=True)
265
+
266
+ def handle_choice(self, choice, model_a, model_b, image_category):
267
+ if not model_a or not model_b:
268
+ return "Please start a comparison first.", gr.Button(interactive=False), gr.Button(interactive=False)
269
+
270
+ winner = model_a if choice == "Result A" else model_b
271
+
272
+ if winner not in self.results_df.index:
273
+ self.results_df.loc[winner] = 0
274
+
275
+ # Increment the main count and the category-specific count
276
+ self.results_df.loc[winner, 'count'] += 1
277
+ if image_category in self.results_df.columns:
278
+ self.results_df.loc[winner, image_category] += 1
279
+
280
+ new_count = self.results_df.loc[winner, 'count']
281
+ self.new_votes_count += 1
282
+
283
+ print(f"Recorded preference for '{winner}' in category '{image_category}'. New count: {new_count}. Total new votes: {self.new_votes_count}")
284
+
285
+ # Always save locally for safety
286
+ self.save_results_to_local_csv()
287
+
288
+ # If threshold is met, trigger a BLOCKING push
289
+ if self.new_votes_count >= self.push_threshold:
290
+ print(f"Vote threshold reached. Initiating blocking push to Hub...")
291
+ self.push_results_to_hub() # This is a direct, blocking call
292
+ self.new_votes_count = 0 # Reset counter
293
+
294
+ reveal_text = f"Thank you! Your preference for **{choice}** has been recorded.\n\n- **Image A was:** {model_a}\n- **Image B was:** {model_b}"
295
+ return reveal_text, gr.Button(interactive=False), gr.Button(interactive=False)
296
+
297
+ @spaces.GPU()
298
+ def playground_upscale(self, image, upscaler_name, upscaler_size, tile, tile_overlap, half):
299
+ if image is None or upscaler_name is None: return None
300
+ return self.process_upscale(image, upscaler_name, upscaler_size, tile, tile_overlap, half)
301
+
302
+ def build_gradio_ui(self):
303
+ """Constructs the Gradio interface."""
304
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
305
+ gr.Markdown("# Image Upscaler GUI with A/B Testing")
306
+
307
+ with gr.Accordion("Advanced Settings", open=True):
308
+ with gr.Row():
309
+ upscaler_size_slider = gr.Slider(minimum=1.1, maximum=4.0, value=2.0, step=0.1, label="Upscale Factor")
310
+ tile_slider = gr.Slider(minimum=0, maximum=1024, value=192, step=16, label="Tile Size (0 is not tile)")
311
+ tile_overlap_slider = gr.Slider(minimum=0, maximum=128, value=8, step=1, label="Tile Overlap")
312
+ half_checkbox = gr.Checkbox(label="Use Half Precision (FP16)", value=True)
313
+
314
+
315
+ with gr.Tab("Blind Test Comparison"):
316
+ gr.Markdown("Upload an image, compare the results, and select your favorite. Your vote is recorded to rank the models.")
317
+ gr.Markdown(
318
+ "> **Disclaimer:** This application **does not store your uploaded images**."
319
+ " It only anonymously records which upscaler you prefer so we can rank them."
320
+ )
321
+ model_a_state = gr.State("")
322
+ model_b_state = gr.State("")
323
+ image_category_state = gr.State("")
324
+ with gr.Row():
325
+ input_image_blind = gr.Image(type="pil", label="Source Image")
326
+ compare_button = gr.Button("Compare Upscalers")
327
+ with gr.Row():
328
+ output_image_a = gr.Image(label="Result A", interactive=False)
329
+ output_image_b = gr.Image(label="Result B", interactive=False)
330
+ with gr.Row():
331
+ choose_a_button = gr.Button("I prefer Result A", interactive=False)
332
+ choose_b_button = gr.Button("I prefer Result B", interactive=False)
333
+ result_text_blind = gr.Markdown("")
334
+
335
+ compare_button.click(
336
+ fn=self.blind_upscale,
337
+ inputs=[input_image_blind, upscaler_size_slider, tile_slider, tile_overlap_slider, half_checkbox],
338
+ outputs=[output_image_a, output_image_b, result_text_blind, model_a_state, model_b_state, image_category_state, choose_a_button, choose_b_button]
339
+ )
340
+ choose_a_button.click(
341
+ fn=lambda a, b, c: self.handle_choice("Result A", a, b, c),
342
+ inputs=[model_a_state, model_b_state, image_category_state],
343
+ outputs=[result_text_blind, choose_a_button, choose_b_button]
344
+ )
345
+ choose_b_button.click(
346
+ fn=lambda a, b, c: self.handle_choice("Result B", a, b, c),
347
+ inputs=[model_a_state, model_b_state, image_category_state],
348
+ outputs=[result_text_blind, choose_a_button, choose_b_button]
349
+ )
350
+
351
+ with gr.Tab("Upscaler Playground"):
352
+ gr.Markdown("Select an upscaler model, choose a scaling factor, and process your image.")
353
+ with gr.Row():
354
+ with gr.Column(scale=1):
355
+ input_image_playground = gr.Image(type="pil", label="Source Image")
356
+ upscaler_model_dropdown = gr.Dropdown(choices=list(UPSCALER_DICT_GUI.keys()), label="Upscaler Model")
357
+ run_button_playground = gr.Button("Run Upscale")
358
+ with gr.Column(scale=2):
359
+ output_image_playground = gr.Image(label="Upscaled Result", interactive=False)
360
+
361
+ run_button_playground.click(
362
+ fn=self.playground_upscale,
363
+ inputs=[input_image_playground, upscaler_model_dropdown, upscaler_size_slider, tile_slider, tile_overlap_slider, half_checkbox],
364
+ outputs=[output_image_playground]
365
+ )
366
+
367
+ return demo
368
+
369
+ def launch(self, **kwargs):
370
+ self.ui.launch(**kwargs)
371
+
372
+ @spaces.GPU
373
+ def dummy_gpu():
374
+ return None
375
+
376
+ # --- Main Execution Block ---
377
+ if __name__ == "__main__":
378
+ # Before launching, ensure the upscalers directory exists
379
+ if not os.path.exists(DIRECTORY_UPSCALERS):
380
+ os.makedirs(DIRECTORY_UPSCALERS)
381
+
382
+ app = UpscalerApp(
383
+ repo_id=DATASET_REPO_ID,
384
+ filename=DATASET_FILENAME,
385
+ local_path=LOCAL_CSV_PATH,
386
+ push_threshold=PUSH_THRESHOLD
387
+ )
388
+ app.launch(debug=True, show_error=True)