um41r commited on
Commit
f59a1e6
·
verified ·
1 Parent(s): 41e2ade

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -67
app.py CHANGED
@@ -1,17 +1,25 @@
1
  import os
2
  import cv2
3
- import uuid
4
  import numpy
5
  import base64
6
  from io import BytesIO
7
  from PIL import Image
8
- from flask import Flask, request, jsonify
 
 
 
 
9
  from basicsr.archs.rrdbnet_arch import RRDBNet
10
  from basicsr.utils.download_util import load_file_from_url
11
  from realesrgan import RealESRGANer
12
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
13
 
14
- app = Flask(__name__)
 
 
 
 
 
15
 
16
  # Create weights directory if it doesn't exist
17
  os.makedirs('weights', exist_ok=True)
@@ -19,7 +27,59 @@ os.makedirs('weights', exist_ok=True)
19
  # Global variable to track image mode
20
  img_mode = "RGBA"
21
 
22
- def process_image(img_data, model_name, denoise_strength, face_enhance, outscale):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  """Real-ESRGAN function to restore (and upscale) images."""
24
  global img_mode
25
 
@@ -48,7 +108,7 @@ def process_image(img_data, model_name, denoise_strength, face_enhance, outscale
48
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
49
  ]
50
  else:
51
- return None, "Invalid model name"
52
 
53
  # Download model if not already available
54
  model_path = os.path.join('weights', model_name + '.pth')
@@ -113,7 +173,7 @@ def process_image(img_data, model_name, denoise_strength, face_enhance, outscale
113
  else:
114
  output, _ = upsampler.enhance(img, outscale=outscale)
115
  except RuntimeError as error:
116
- return None, f"Error: {str(error)}"
117
 
118
  # Convert back to appropriate format based on mode
119
  if img_mode == "RGBA":
@@ -133,30 +193,47 @@ def process_image(img_data, model_name, denoise_strength, face_enhance, outscale
133
 
134
  return output_img, properties
135
 
136
- @app.route('/enhancer', methods=['POST'])
137
- def enhance_image():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  try:
139
- # Check if file is present in request
140
- if 'image' not in request.files:
141
- return jsonify({"error": "No image file provided"}), 400
 
 
 
 
142
 
143
- # Get image file
144
- image_file = request.files['image']
 
145
 
146
- # Get parameters with defaults
147
- model_name = request.form.get('model', 'RealESRGAN_x4plus')
148
- denoise_strength = float(request.form.get('denoise_strength', 0.5))
149
- outscale = int(request.form.get('outscale', 4))
150
- face_enhance = request.form.get('face_enhance', 'false').lower() == 'true'
151
 
152
- # Open image using PIL
153
- img = Image.open(image_file)
 
154
 
155
  # Process image
156
- output_img, properties = process_image(img, model_name, denoise_strength, face_enhance, outscale)
157
-
158
- if output_img is None:
159
- return jsonify({"error": properties}), 500
160
 
161
  # Convert to PIL Image and then to base64
162
  output_pil = Image.fromarray(output_img)
@@ -172,53 +249,35 @@ def enhance_image():
172
  img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
173
 
174
  # Return response
175
- return jsonify({
176
  "enhanced_image": img_str,
177
  "properties": properties,
178
- "model_used": model_name
179
- })
180
 
 
 
181
  except Exception as e:
182
- return jsonify({"error": str(e)}), 500
183
-
184
- # Add a simple health check route
185
- @app.route('/', methods=['GET'])
186
- def health_check():
187
- return jsonify({"status": "healthy", "message": "Image enhancement server is running"})
188
-
189
- # Add a route to list available models
190
- @app.route('/models', methods=['GET'])
191
- def list_models():
192
- models = [
193
- {
194
- "name": "RealESRGAN_x4plus",
195
- "description": "General purpose 4x upscaling model",
196
- "scale": 4
197
- },
198
- {
199
- "name": "RealESRNet_x4plus",
200
- "description": "Alternative 4x upscaling model",
201
- "scale": 4
202
- },
203
- {
204
- "name": "RealESRGAN_x4plus_anime_6B",
205
- "description": "Specialized for anime/cartoon images, 4x upscaling",
206
- "scale": 4
207
- },
208
- {
209
- "name": "RealESRGAN_x2plus",
210
- "description": "2x upscaling model",
211
- "scale": 2
212
- },
213
- {
214
- "name": "realesr-general-x4v3",
215
- "description": "General purpose 4x upscaling model with denoise control",
216
- "scale": 4
217
- }
218
- ]
219
- return jsonify({"models": models})
220
 
221
  if __name__ == "__main__":
222
- # Run server with debug mode enabled
223
  port = int(os.environ.get("PORT", 8000))
224
- app.run(host="0.0.0.0", port=port)
 
1
  import os
2
  import cv2
 
3
  import numpy
4
  import base64
5
  from io import BytesIO
6
  from PIL import Image
7
+ from fastapi import FastAPI, File, Form, UploadFile, HTTPException
8
+ from fastapi.responses import JSONResponse
9
+ from typing import Optional, List, Dict, Any, Union
10
+ from pydantic import BaseModel
11
+ import uvicorn
12
  from basicsr.archs.rrdbnet_arch import RRDBNet
13
  from basicsr.utils.download_util import load_file_from_url
14
  from realesrgan import RealESRGANer
15
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
16
 
17
+ # Create FastAPI app
18
+ app = FastAPI(
19
+ title="Image Enhancement API",
20
+ description="API for enhancing and upscaling images using Real-ESRGAN models",
21
+ version="1.0.0"
22
+ )
23
 
24
  # Create weights directory if it doesn't exist
25
  os.makedirs('weights', exist_ok=True)
 
27
  # Global variable to track image mode
28
  img_mode = "RGBA"
29
 
30
+ # Models information
31
+ AVAILABLE_MODELS = [
32
+ {
33
+ "name": "RealESRGAN_x4plus",
34
+ "description": "General purpose 4x upscaling model",
35
+ "scale": 4
36
+ },
37
+ {
38
+ "name": "RealESRNet_x4plus",
39
+ "description": "Alternative 4x upscaling model",
40
+ "scale": 4
41
+ },
42
+ {
43
+ "name": "RealESRGAN_x4plus_anime_6B",
44
+ "description": "Specialized for anime/cartoon images, 4x upscaling",
45
+ "scale": 4
46
+ },
47
+ {
48
+ "name": "RealESRGAN_x2plus",
49
+ "description": "2x upscaling model",
50
+ "scale": 2
51
+ },
52
+ {
53
+ "name": "realesr-general-x4v3",
54
+ "description": "General purpose 4x upscaling model with denoise control",
55
+ "scale": 4
56
+ }
57
+ ]
58
+
59
+ # Pydantic models for API documentation
60
+ class HealthResponse(BaseModel):
61
+ status: str
62
+ message: str
63
+
64
+ class ModelInfo(BaseModel):
65
+ name: str
66
+ description: str
67
+ scale: int
68
+
69
+ class ModelsResponse(BaseModel):
70
+ models: List[ModelInfo]
71
+
72
+ class ImageProperties(BaseModel):
73
+ width: int
74
+ height: int
75
+ mode: str
76
+
77
+ class EnhancementResponse(BaseModel):
78
+ enhanced_image: str
79
+ properties: ImageProperties
80
+ model_used: str
81
+
82
+ async def process_image(img_data, model_name, denoise_strength, face_enhance, outscale):
83
  """Real-ESRGAN function to restore (and upscale) images."""
84
  global img_mode
85
 
 
108
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
109
  ]
110
  else:
111
+ raise HTTPException(status_code=400, detail=f"Invalid model name: {model_name}")
112
 
113
  # Download model if not already available
114
  model_path = os.path.join('weights', model_name + '.pth')
 
173
  else:
174
  output, _ = upsampler.enhance(img, outscale=outscale)
175
  except RuntimeError as error:
176
+ raise HTTPException(status_code=500, detail=f"Processing error: {str(error)}")
177
 
178
  # Convert back to appropriate format based on mode
179
  if img_mode == "RGBA":
 
193
 
194
  return output_img, properties
195
 
196
+ @app.post("/enhancer", response_model=EnhancementResponse, summary="Enhance and upscale an image")
197
+ async def enhance_image(
198
+ image: UploadFile = File(..., description="Image file to enhance"),
199
+ model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
200
+ denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
201
+ outscale: int = Form(4, description="Output scale factor"),
202
+ face_enhance: bool = Form(False, description="Enable face enhancement")
203
+ ):
204
+ """
205
+ Enhance and upscale an image using Real-ESRGAN models.
206
+
207
+ - **image**: Upload an image file (PNG, JPG, etc.)
208
+ - **model**: Select a model from the available options
209
+ - **denoise_strength**: Control the denoising strength (only for realesr-general-x4v3)
210
+ - **outscale**: Control the output resolution scaling
211
+ - **face_enhance**: Enable face enhancement using GFPGAN
212
+
213
+ Returns the enhanced image as a base64 string along with image properties.
214
+ """
215
  try:
216
+ # Validate model name
217
+ valid_models = [m["name"] for m in AVAILABLE_MODELS]
218
+ if model not in valid_models:
219
+ raise HTTPException(
220
+ status_code=400,
221
+ detail=f"Invalid model. Choose from: {', '.join(valid_models)}"
222
+ )
223
 
224
+ # Validate other parameters
225
+ if not (0 <= denoise_strength <= 1):
226
+ raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
227
 
228
+ if not (1 <= outscale <= 8):
229
+ raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")
 
 
 
230
 
231
+ # Read the image file
232
+ contents = await image.read()
233
+ img = Image.open(BytesIO(contents))
234
 
235
  # Process image
236
+ output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)
 
 
 
237
 
238
  # Convert to PIL Image and then to base64
239
  output_pil = Image.fromarray(output_img)
 
249
  img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
250
 
251
  # Return response
252
+ return {
253
  "enhanced_image": img_str,
254
  "properties": properties,
255
+ "model_used": model
256
+ }
257
 
258
+ except HTTPException as e:
259
+ raise e
260
  except Exception as e:
261
+ raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")
262
+
263
+ @app.get("/", response_model=HealthResponse, summary="Check server health")
264
+ async def health_check():
265
+ """Check if the image enhancement server is running."""
266
+ return {"status": "healthy", "message": "Image enhancement server is running"}
267
+
268
+ @app.get("/models", response_model=ModelsResponse, summary="List available models")
269
+ async def list_models():
270
+ """Get a list of all available enhancement models with descriptions."""
271
+ return {"models": AVAILABLE_MODELS}
272
+
273
+ # Add startup event to print server info
274
+ @app.on_event("startup")
275
+ async def startup_event():
276
+ print("🚀 Image Enhancement API is starting up!")
277
+ print(f"📚 Available models: {', '.join(m['name'] for m in AVAILABLE_MODELS)}")
278
+ print("📋 API documentation available at /docs or /redoc")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  if __name__ == "__main__":
281
+ # Run server with Uvicorn
282
  port = int(os.environ.get("PORT", 8000))
283
+ uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)