File size: 13,106 Bytes
41e2ade
 
 
 
 
 
f59a1e6
39cf1a9
 
f59a1e6
 
 
41e2ade
 
 
 
 
f59a1e6
 
 
 
 
 
41e2ade
39cf1a9
 
 
 
 
 
 
 
 
41e2ade
 
 
 
 
 
f59a1e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41e2ade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f59a1e6
41e2ade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f59a1e6
41e2ade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39cf1a9
 
 
 
 
 
f59a1e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41e2ade
f59a1e6
 
 
 
 
 
 
41e2ade
f59a1e6
 
 
41e2ade
f59a1e6
 
41e2ade
f59a1e6
 
 
41e2ade
 
f59a1e6
41e2ade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f59a1e6
41e2ade
 
f59a1e6
 
41e2ade
f59a1e6
 
41e2ade
f59a1e6
 
39cf1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f59a1e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41e2ade
 
39cf1a9
 
f59a1e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import os
import cv2
import numpy
import base64
from io import BytesIO
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel
import uvicorn
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

# Create the FastAPI application. The title, description and version given
# here are the metadata FastAPI publishes in the generated API docs.
app = FastAPI(
    title="Image Enhancement API",
    description="API for enhancing and upscaling images using Real-ESRGAN models",
    version="1.0.0"
)

# Add CORS middleware so pages served from other origins can call this API.
# NOTE(review): every origin, method and header is allowed while credentials
# are also enabled. A wildcard origin combined with allow_credentials=True is
# discouraged for CORS; nothing in this file reads cookies or auth headers, so
# confirm whether credentials are actually needed before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # wide open; restrict to known front-end origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create the weights directory if it doesn't exist. The path is relative to
# the current working directory; process_image() first looks for
# 'weights/<model_name>.pth' relative to the working directory and otherwise
# downloads into the 'weights' folder located next to this file.
os.makedirs('weights', exist_ok=True)

# Module-level record of the colour mode ("RGBA" or "RGB") of the most recent
# input. process_image() overwrites it on every call, so the initial value
# below is only a placeholder until the first request is processed.
img_mode = "RGBA"

# Catalogue of the models this service exposes, one row per model:
# (name, description, upscaling factor). The order here is the order in which
# the models are listed to clients.
_MODEL_CATALOGUE = (
    ("RealESRGAN_x4plus", "General purpose 4x upscaling model", 4),
    ("RealESRNet_x4plus", "Alternative 4x upscaling model", 4),
    ("RealESRGAN_x4plus_anime_6B", "Specialized for anime/cartoon images, 4x upscaling", 4),
    ("RealESRGAN_x2plus", "2x upscaling model", 2),
    ("realesr-general-x4v3", "General purpose 4x upscaling model with denoise control", 4),
)

# Public list of model descriptors (dicts with "name", "description", "scale").
AVAILABLE_MODELS = [
    {"name": model_name, "description": summary, "scale": factor}
    for model_name, summary, factor in _MODEL_CATALOGUE
]

# Pydantic models for API documentation
class HealthResponse(BaseModel):
    # Response schema declared by the "/" and "/health" endpoints.
    status: str  # short state label; the handlers in this file send "ok" or "healthy"
    message: str  # human-readable sentence describing the state

class ModelInfo(BaseModel):
    # One entry of the model listing; the fields mirror the keys of the dicts
    # stored in AVAILABLE_MODELS.
    name: str  # identifier clients pass as the "model" form field
    description: str  # one-line summary of what the model is for
    scale: int  # upscaling factor listed for the model (2 or 4 in this file)

class ModelsResponse(BaseModel):
    # Response schema declared by the "/models" endpoint.
    models: List[ModelInfo]  # one ModelInfo per available model

class ImageProperties(BaseModel):
    # Basic facts about an enhanced image, as computed by process_image().
    width: int  # output width in pixels
    height: int  # output height in pixels
    mode: str  # "RGBA", "RGB" or "Grayscale", derived from the output channel count

class EnhancementResponse(BaseModel):
    # Response schema declared by the "/enhancer" endpoint.
    enhanced_image: str  # base64 text of the encoded image file (PNG for RGBA output, JPEG otherwise)
    properties: ImageProperties  # dimensions and colour mode of the enhanced image
    model_used: str  # the model name the request was processed with

def _select_network(model_name):
    """Instantiate the network architecture for ``model_name``.

    Returns:
        A ``(network, netscale, weight_urls)`` tuple: the freshly constructed
        architecture, the scale the network was built for, and the list of
        release URLs of the weight files it uses.

    Raises:
        HTTPException: 400 when ``model_name`` is not a supported model.
    """
    # RRDBNet-based models differ only in block count, scale and weight file.
    rrdb_specs = {
        'RealESRGAN_x4plus': (
            23, 4,
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'),
        'RealESRNet_x4plus': (
            23, 4,
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'),
        'RealESRGAN_x4plus_anime_6B': (
            6, 4,
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'),
        'RealESRGAN_x2plus': (
            23, 2,
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'),
    }
    if model_name in rrdb_specs:
        num_block, netscale, url = rrdb_specs[model_name]
        network = RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=netscale)
        return network, netscale, [url]

    if model_name == 'realesr-general-x4v3':
        network = SRVGGNetCompact(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        # Order matters: the caller keeps the path of the LAST download as the
        # primary weight file, so the plain x4v3 weights must come last.
        return network, 4, [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]

    raise HTTPException(status_code=400, detail=f"Invalid model name: {model_name}")


def _to_bgr_array(img_data):
    """Convert the input into a BGR/BGRA array for OpenCV-style processing.

    Args:
        img_data: a PIL image in any mode, or a numpy array that is already
            in OpenCV channel order (2-D grayscale, or 3-D with 3/4 channels).

    Returns:
        ``(array, mode)`` where ``mode`` is "RGBA" when the array has four
        channels and "RGB" otherwise.
    """
    if isinstance(img_data, Image.Image):
        # Only RGB and RGBA can be channel-swapped below; normalise every
        # other mode (grayscale, palette, CMYK, ...) first instead of letting
        # cv2.cvtColor fail on an unexpected channel count.
        has_alpha = img_data.mode in ("RGBA", "LA", "PA") or (
            img_data.mode == "P" and "transparency" in img_data.info)
        target_mode = "RGBA" if has_alpha else "RGB"
        if img_data.mode != target_mode:
            img_data = img_data.convert(target_mode)

        pixels = numpy.array(img_data)
        if target_mode == "RGBA":
            return cv2.cvtColor(pixels, cv2.COLOR_RGBA2BGRA), "RGBA"
        return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR), "RGB"

    # Already a numpy array: taken as-is, except that a 2-D (single channel)
    # array is expanded to three channels so the shape lookup below is valid.
    img = img_data
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    return img, ("RGBA" if img.shape[2] == 4 else "RGB")


async def process_image(img_data, model_name, denoise_strength, face_enhance, outscale):
    """Real-ESRGAN function to restore (and upscale) images.

    Args:
        img_data: PIL image or numpy array to enhance.
        model_name: one of the supported Real-ESRGAN model names.
        denoise_strength: blend weight between the regular and the "wdn"
            weights; only used for 'realesr-general-x4v3' and only when it is
            not exactly 1.
        face_enhance: when true, run the image through GFPGAN with the
            Real-ESRGAN upsampler as background upsampler.
        outscale: requested output scale factor.

    Returns:
        ``(output_img, properties)``: the enhanced image as an RGB/RGBA numpy
        array and a dict with its "width", "height" and "mode".

    Raises:
        HTTPException: 400 for an unknown ``model_name``; 500 when the
            enhancement step raises ``RuntimeError``.

    Note:
        There is no ``await`` in this coroutine, so the whole enhancement runs
        synchronously inside the caller's event loop, and a new upsampler is
        constructed on every call.
    """
    global img_mode

    model, netscale, file_url = _select_network(model_name)

    # Download model if not already available. The lookup is relative to the
    # working directory; downloads go to the 'weights' folder next to this file.
    model_path = os.path.join('weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
        for url in file_url:
            model_path = load_file_from_url(
                url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)

    # Handle denoise strength for realesr-general-x4v3: pass both weight files
    # together with their blend weights.
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Initialize upsampler
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=False,
        gpu_id=None
    )

    # Initialize face enhancer if needed. Imported lazily so gfpgan is only
    # required when face enhancement is actually requested.
    face_enhancer = None
    if face_enhance:
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)

    img, input_mode = _to_bgr_array(img_data)
    # Mirror the detected mode into the module-level variable for any code
    # that still reads it; the logic below relies on local values only.
    img_mode = input_mode

    try:
        # Process image
        if face_enhancer is not None:
            _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            output, _ = upsampler.enhance(img, outscale=outscale)
    except RuntimeError as error:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(error)}") from error

    # Convert back to RGB(A). The conversion is chosen from the channels the
    # output actually has, so it cannot disagree with what the enhancer
    # returned (e.g. if the alpha channel was dropped along the way).
    if output.ndim == 3 and output.shape[2] == 4:
        output_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
    else:
        output_img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)

    # Get image properties
    height, width = output_img.shape[:2]
    channels = output_img.shape[2] if len(output_img.shape) > 2 else 1

    properties = {
        "width": width,
        "height": height,
        "mode": "RGBA" if channels == 4 else "RGB" if channels == 3 else "Grayscale"
    }

    return output_img, properties

# Root endpoint for health check - important for Spaces
@app.get("/", response_model=HealthResponse)
async def read_root():
    """Check if the image enhancement API is running."""
    # Static payload: answering at all is the health signal.
    payload = dict(status="ok", message="Image Enhancement API is running")
    return payload

@app.post("/enhancer", response_model=EnhancementResponse, summary="Enhance and upscale an image")
async def enhance_image(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement")
):
    """
    Enhance and upscale an image using Real-ESRGAN models.
    
    - **image**: Upload an image file (PNG, JPG, etc.)
    - **model**: Select a model from the available options
    - **denoise_strength**: Control the denoising strength (only for realesr-general-x4v3)
    - **outscale**: Control the output resolution scaling
    - **face_enhance**: Enable face enhancement using GFPGAN
    
    Returns the enhanced image as a base64 string along with image properties.

    Responds with 400 when a parameter is out of range or the upload is not a
    readable image, and with 500 when processing fails.
    """
    try:
        # Validate model name
        valid_models = [m["name"] for m in AVAILABLE_MODELS]
        if model not in valid_models:
            raise HTTPException(
                status_code=400, 
                detail=f"Invalid model. Choose from: {', '.join(valid_models)}"
            )
        
        # Validate other parameters
        if not (0 <= denoise_strength <= 1):
            raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
        
        if not (1 <= outscale <= 8):
            raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")
        
        # Read and decode the upload. Decoding is forced here (load) so that
        # an empty, truncated or non-image file is reported as a client error
        # instead of surfacing later as a generic 500.
        contents = await image.read()
        try:
            img = Image.open(BytesIO(contents))
            img.load()
        except (OSError, Image.DecompressionBombError) as error:
            raise HTTPException(
                status_code=400,
                detail="Uploaded file is not a valid image"
            ) from error
        
        # Process image
        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)
        
        # Convert to PIL Image and then to base64
        output_pil = Image.fromarray(output_img)
        
        # Save to buffer: PNG keeps the alpha channel, everything else is JPEG
        buffer = BytesIO()
        if properties["mode"] == "RGBA":
            output_pil.save(buffer, format="PNG")
        else:
            output_pil.save(buffer, format="JPEG", quality=95)
        
        # Encode to base64
        img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
        
        # Return response
        return {
            "enhanced_image": img_str,
            "properties": properties,
            "model_used": model
        }
    
    except HTTPException:
        # Deliberate HTTP errors pass through untouched (bare raise keeps the
        # original traceback).
        raise
    except Exception as e:
        # Last-resort boundary: anything unexpected becomes a 500.
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e

# Add a direct image endpoint that returns the actual image instead of base64
@app.post("/enhancer/image", summary="Enhance and return the image directly")
async def enhance_image_direct(
    image: UploadFile = File(..., description="Image file to enhance"),
    model: str = Form("RealESRGAN_x4plus", description="Model name to use for enhancement"),
    denoise_strength: float = Form(0.5, description="Denoise strength (0-1)"),
    outscale: int = Form(4, description="Output scale factor"),
    face_enhance: bool = Form(False, description="Enable face enhancement")
):
    """
    Enhance and upscale an image, returning the actual image file directly.
    
    This endpoint works like /enhancer but returns the image directly instead of base64 encoded.
    This is useful for direct image display or download.

    Responds with 400 when a parameter is out of range or the upload is not a
    readable image, and with 500 when processing fails.
    """
    try:
        # Validate model name
        valid_models = [m["name"] for m in AVAILABLE_MODELS]
        if model not in valid_models:
            raise HTTPException(
                status_code=400, 
                detail=f"Invalid model. Choose from: {', '.join(valid_models)}"
            )
        
        # Validate other parameters with the same limits as /enhancer, so the
        # two endpoints accept exactly the same requests.
        if not (0 <= denoise_strength <= 1):
            raise HTTPException(status_code=400, detail="Denoise strength must be between 0 and 1")
        
        if not (1 <= outscale <= 8):
            raise HTTPException(status_code=400, detail="Outscale must be between 1 and 8")
        
        # Read and decode the upload. Decoding is forced here (load) so that
        # an empty, truncated or non-image file is reported as a client error
        # instead of surfacing later as a generic 500.
        contents = await image.read()
        try:
            img = Image.open(BytesIO(contents))
            img.load()
        except (OSError, Image.DecompressionBombError) as error:
            raise HTTPException(
                status_code=400,
                detail="Uploaded file is not a valid image"
            ) from error
        
        # Process image
        output_img, properties = await process_image(img, model, denoise_strength, face_enhance, outscale)
        
        # Convert to PIL Image
        output_pil = Image.fromarray(output_img)
        
        # Save to buffer: PNG keeps the alpha channel, everything else is JPEG
        buffer = BytesIO()
        if properties["mode"] == "RGBA":
            output_pil.save(buffer, format="PNG")
            media_type = "image/png"
        else:
            output_pil.save(buffer, format="JPEG", quality=95)
            media_type = "image/jpeg"
        
        # Return the image directly; getvalue() yields the whole buffer
        # regardless of the current position, so no seek is needed.
        return Response(content=buffer.getvalue(), media_type=media_type)
    
    except HTTPException:
        # Deliberate HTTP errors pass through untouched (bare raise keeps the
        # original traceback).
        raise
    except Exception as e:
        # Last-resort boundary: anything unexpected becomes a 500.
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e

@app.get("/health", response_model=HealthResponse, summary="Check server health")
async def health_check():
    """Check if the image enhancement server is running."""
    # Static payload: no dependency is probed, a reply means the process is up.
    report = dict(status="healthy", message="Image enhancement server is running")
    return report

@app.get("/models", response_model=ModelsResponse, summary="List available models")
async def list_models():
    """Get a list of all available enhancement models with descriptions."""
    # The module-level catalogue is returned as-is under the "models" key.
    listing = dict(models=AVAILABLE_MODELS)
    return listing

# Add startup event to print server info
@app.on_event("startup")
async def startup_event():
    """Print a short banner with the available models when the app starts."""
    # The emoji prefixes were stored as mojibake (their UTF-8 bytes decoded
    # with a legacy code page); they are restored to the intended characters.
    print("🚀 Image Enhancement API is starting up!")
    print(f"📚 Available models: {', '.join(m['name'] for m in AVAILABLE_MODELS)}")
    print("📋 API documentation available at /docs or /redoc")

if __name__ == "__main__":
    # Script entry point: serve the app with Uvicorn on all interfaces. The
    # port comes from the PORT environment variable and falls back to 7860,
    # the port used for Hugging Face Spaces.
    # NOTE(review): the "app:app" import string presumes this module is named
    # app.py, and reload=True enables auto-reload — confirm both are intended
    # for the deployed environment.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, reload=True)