# restoration-api / app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from flask import Flask, request, jsonify
from flask_cors import CORS
import numpy as np
import cv2
import base64
from io import BytesIO
from PIL import Image
# Flask application and CORS setup so browser front-ends on other origins
# can call this API directly.
app = Flask(__name__)
CORS(app)
DEVICE = torch.device('cpu') # Force CPU for HF Free Tier
MODEL_PATH = "best_model_fixed.pth" # Upload your trained .pth file
class InpaintingGenerator(nn.Module):
    """U-Net style inpainting generator built on a ResNet-34 encoder.

    The forward pass concatenates an RGB image with a single-channel
    binary mask (4 input channels total) and produces a 3-channel
    image through a Tanh, i.e. values in [-1, 1].  Spatial size is
    restored to the input size, which is assumed divisible by 16
    (four stride-2 encoder stages, four stride-2 decoder stages).

    NOTE: submodule attribute names (enc1..enc5, bottleneck, up1..up4,
    texture_refine, final) are part of the checkpoint's state_dict
    schema and must not be renamed.
    """

    def __init__(self, input_channels=4):
        super().__init__()
        backbone = models.resnet34(weights=None)
        # Stem: a fresh 4-channel conv (RGB + mask) replacing the stock
        # 3-channel one; the backbone's BN and ReLU are reused as-is.
        # The standard ResNet maxpool is deliberately omitted.
        self.enc1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            backbone.bn1, backbone.relu
        )
        self.enc2 = backbone.layer1  # 64 ch
        self.enc3 = backbone.layer2  # 128 ch, /2
        self.enc4 = backbone.layer3  # 256 ch, /2
        self.enc5 = backbone.layer4  # 512 ch, /2
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(True)
        )
        # Each up-block halves the channel count after the skip concat
        # doubled it (hence the in-channel sums noted below).
        self.up1 = self._make_decoder_block(512, 256)
        self.up2 = self._make_decoder_block(512, 128)  # 256 + 256 from skip
        self.up3 = self._make_decoder_block(256, 64)   # 128 + 128 from skip
        self.up4 = self._make_decoder_block(128, 32)   # 64 + 64 from skip
        self.texture_refine = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(True)
        )
        self.final = nn.Sequential(
            nn.Conv2d(32, 16, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(16, 3, 3, padding=1), nn.Tanh()
        )

    def _make_decoder_block(self, in_ch, out_ch):
        """Upsample x2 via transposed conv, then refine with two conv-BN-ReLU passes."""
        stages = [
            nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch), nn.ReLU(True),
        ]
        for _ in range(2):
            stages += [
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch), nn.ReLU(True),
            ]
        return nn.Sequential(*stages)

    def forward(self, img, mask):
        """Run (B,3,H,W) image + (B,1,H,W) mask -> (B,3,H,W) output in [-1,1]."""
        fused = torch.cat([img, mask], dim=1)
        s1 = self.enc1(fused)
        s2 = self.enc2(s1)
        s3 = self.enc3(s2)
        s4 = self.enc4(s3)
        s5 = self.enc5(s4)
        # Decoder with skip connections from the three middle encoder stages.
        d = self.up1(self.bottleneck(s5))
        d = self.up2(torch.cat([d, s4], dim=1))
        d = self.up3(torch.cat([d, s3], dim=1))
        d = self.up4(torch.cat([d, s2], dim=1))
        return self.final(self.texture_refine(d))
print("Loading Inpainting Model...")
model = InpaintingGenerator().to(DEVICE)
try:
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # checkpoint file -- only load .pth files from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    # The checkpoint may be a full training dict holding the generator under
    # a 'generator' key, or a bare state_dict.
    state_dict = checkpoint['generator'] if 'generator' in checkpoint else checkpoint
    # Strip the 'module.' prefix that nn.DataParallel prepends -- only at the
    # start of each key (replace() would also mangle it mid-key).
    new_state_dict = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in state_dict.items()
    }
    # strict=False tolerates schema drift; report what was skipped instead of
    # hiding it silently.
    missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    if missing or unexpected:
        print(f"Warning: {len(missing)} missing / {len(unexpected)} unexpected keys")
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
# Always switch to eval mode, even if loading failed: otherwise BatchNorm
# layers would run in training mode (per-batch statistics) during inference.
model.eval()
def to_base64(image_array):
    """Encode a uint8 numpy image array as a base64 PNG string (no data-URI prefix)."""
    buf = BytesIO()
    Image.fromarray(image_array).save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode('utf-8')
@app.route('/')
def home():
    """Health-check endpoint confirming the service is reachable."""
    status_message = "Inpainting API is Running!"
    return status_message
@app.route('/inpaint', methods=['POST'])
def inpaint():
    """Inpaint the masked region of an uploaded image.

    Expects multipart form files 'image' (a photo) and 'mask', where any
    non-zero mask pixel marks an area to restore (class-index masks with
    low values 1, 2, 3... are supported).  Returns JSON with a base64 PNG
    data URI under 'result' (200), or an 'error' message (400/500).
    """
    if 'image' not in request.files or 'mask' not in request.files:
        return jsonify({'error': 'Please upload both image and mask'}), 400
    try:
        # 1. Decode image. imdecode returns None on undecodable bytes --
        # report that as a client error instead of a cryptic 500.
        img_arr = np.frombuffer(request.files['image'].read(), np.uint8)
        img_cv = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
        if img_cv is None:
            return jsonify({'error': 'Could not decode image file'}), 400
        img_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

        # 2. Decode mask "unchanged" to preserve low class-index values (1, 2, 3).
        mask_arr = np.frombuffer(request.files['mask'].read(), np.uint8)
        mask_cv = cv2.imdecode(mask_arr, cv2.IMREAD_UNCHANGED)
        if mask_cv is None:
            return jsonify({'error': 'Could not decode mask file'}), 400
        if mask_cv.ndim == 3:
            # COLOR_BGR2GRAY rejects 4-channel input, so handle BGRA explicitly.
            code = cv2.COLOR_BGRA2GRAY if mask_cv.shape[2] == 4 else cv2.COLOR_BGR2GRAY
            mask_cv = cv2.cvtColor(mask_cv, code)

        # 3. Preprocess: model expects 512x512; remember the original size.
        img_h, img_w = img_cv.shape[:2]
        img_resized = cv2.resize(img_cv, (512, 512))
        # Nearest-neighbour keeps exact class IDs (0, 1, 2...) intact.
        mask_resized = cv2.resize(mask_cv, (512, 512), interpolation=cv2.INTER_NEAREST)

        # Normalize image to [-1, 1] to match the generator's Tanh output range.
        img_tensor = (torch.from_numpy(img_resized).float() / 127.5) - 1.0
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        # Any non-zero pixel (class index >= 1) becomes 1.0 in the binary mask --
        # deliberately > 0, not > 127, so class-index masks work.
        mask_tensor = (torch.from_numpy(mask_resized).float() > 0).float()
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(DEVICE)

        # 4. Inference
        with torch.no_grad():
            output = model(img_tensor, mask_tensor)

        # 5. Post-process: [-1, 1] -> uint8 RGB at the original resolution.
        # squeeze(0) drops only the batch dim (a bare squeeze() would also
        # collapse any other singleton dimension).
        output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
        output_np = np.clip((output_np + 1.0) * 127.5, 0, 255).astype(np.uint8)
        output_final = cv2.resize(output_np, (img_w, img_h))
        return jsonify({'result': f"data:image/png;base64,{to_base64(output_final)}"})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # 0.0.0.0 exposes the server outside the container.
    # NOTE(review): port 7860 presumably matches the hosting platform's
    # expected port (Hugging Face Spaces default) -- confirm deployment target.
    app.run(host='0.0.0.0', port=7860)