"""Flask inference API for a ResNet-34 U-Net image-inpainting generator.

Exposes:
    GET  /         -- liveness check
    POST /inpaint  -- multipart upload of 'image' + 'mask', returns a
                      base64 PNG data-URI of the inpainted result.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from flask import Flask, request, jsonify
from flask_cors import CORS
import numpy as np
import cv2
import base64
from io import BytesIO
from PIL import Image

app = Flask(__name__)
CORS(app)

DEVICE = torch.device('cpu')  # Force CPU for HF Free Tier
MODEL_PATH = "best_model_fixed.pth"  # Upload your trained .pth file


class InpaintingGenerator(nn.Module):
    """U-Net-style generator: ResNet-34 encoder, transposed-conv decoder.

    Input is the RGB image concatenated with a binary mask (4 channels);
    output is a 3-channel image in [-1, 1] (Tanh). For a 512x512 input the
    encoder downsamples to 16x16 and the decoder restores 512x512.
    """

    def __init__(self, input_channels=4):
        super().__init__()
        # weights=None: architecture only; trained weights come from the
        # checkpoint loaded below, so no pretrained download is needed.
        resnet = models.resnet34(weights=None)

        # Stem replaced so it accepts image+mask (4 ch) instead of 3;
        # bn1/relu are reused from the ResNet stem.
        self.enc1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2,
                      padding=3, bias=False),
            resnet.bn1,
            resnet.relu,
        )
        self.enc2 = resnet.layer1  # 64 ch,  same resolution
        self.enc3 = resnet.layer2  # 128 ch, /2
        self.enc4 = resnet.layer3  # 256 ch, /2
        self.enc5 = resnet.layer4  # 512 ch, /2

        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        )

        # Decoder blocks upsample x2 each. Skip connections concatenate
        # encoder features, doubling the input channels of the next block.
        self.up1 = self._make_decoder_block(512, 256)
        self.up2 = self._make_decoder_block(512, 128)  # 256+256
        self.up3 = self._make_decoder_block(256, 64)   # 128+128
        self.up4 = self._make_decoder_block(128, 32)   # 64+64

        # Extra conv stack at full resolution to refine fine texture.
        self.texture_refine = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
        )
        self.final = nn.Sequential(
            nn.Conv2d(32, 16, 3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(16, 3, 3, padding=1),
            nn.Tanh(),  # output in [-1, 1], matching the input normalization
        )

    def _make_decoder_block(self, in_channels, out_channels):
        """One decoder stage: x2 upsample followed by two refinement convs."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward(self, img, mask):
        """Run inpainting.

        Args:
            img:  (B, 3, H, W) tensor normalized to [-1, 1].
            mask: (B, 1, H, W) binary tensor (1 = region to inpaint).
        Returns:
            (B, 3, H, W) tensor in [-1, 1].
        """
        x = torch.cat([img, mask], dim=1)
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        x5 = self.enc5(x4)
        x = self.bottleneck(x5)
        # Decode with skip connections from the matching encoder stage.
        x = self.up1(x)
        x = torch.cat([x, x4], dim=1)
        x = self.up2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.up3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up4(x)
        x = self.texture_refine(x)
        return self.final(x)


print("Loading Inpainting Model...")
model = InpaintingGenerator().to(DEVICE)
try:
    # Set weights_only=False to avoid numpy errors.
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects;
    # only load checkpoints from a trusted source (this one is bundled with
    # the deployment, not user-supplied).
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    # Checkpoint may be a raw state_dict or a dict containing 'generator'.
    if 'generator' in checkpoint:
        state_dict = checkpoint['generator']
    else:
        state_dict = checkpoint
    # Handle DataParallel wrapping: strip the 'module.' prefix if present.
    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    # Best-effort startup: keep the server alive so '/' still responds,
    # even though '/inpaint' will not produce meaningful output.
    print(f"Error loading model: {e}")


def to_base64(image_array):
    """Encode an HxWx3 uint8 RGB array as a base64 PNG string."""
    img = Image.fromarray(image_array)
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


@app.route('/')
def home():
    """Liveness check."""
    return "Inpainting API is Running!"


@app.route('/inpaint', methods=['POST'])
def inpaint():
    """Inpaint the masked region of an uploaded image.

    Expects multipart form files 'image' (RGB) and 'mask' (any nonzero
    pixel marks the region to fill). Returns JSON with a base64 PNG
    data-URI under 'result', or {'error': ...} with 400/500.
    """
    if 'image' not in request.files or 'mask' not in request.files:
        return jsonify({'error': 'Please upload both image and mask'}), 400
    try:
        # 1. Read Image
        img_file = request.files['image']
        img_arr = np.frombuffer(img_file.read(), np.uint8)
        img_cv = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
        if img_cv is None:
            # imdecode returns None (not an exception) on bad input; fail
            # with a clear 400 instead of an AttributeError-driven 500.
            return jsonify({'error': 'Could not decode image file'}), 400
        img_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

        # 2. Read Mask
        mask_file = request.files['mask']
        mask_arr = np.frombuffer(mask_file.read(), np.uint8)
        # [CRITICAL FIX] Read "unchanged" to preserve the low values (1, 2, 3)
        mask_cv = cv2.imdecode(mask_arr, cv2.IMREAD_UNCHANGED)
        if mask_cv is None:
            return jsonify({'error': 'Could not decode mask file'}), 400
        # If mask is RGB/RGBA, convert to grayscale
        if len(mask_cv.shape) > 2:
            mask_cv = cv2.cvtColor(mask_cv, cv2.COLOR_BGR2GRAY)

        # 3. Preprocess
        img_h, img_w = img_cv.shape[:2]
        img_resized = cv2.resize(img_cv, (512, 512))
        # Resize mask carefully (Nearest Neighbor preserves exact class IDs 0,1,2...)
        mask_resized = cv2.resize(mask_cv, (512, 512),
                                  interpolation=cv2.INTER_NEAREST)

        # Normalize Image to [-1, 1] and reorder to NCHW.
        img_tensor = (torch.tensor(img_resized).float() / 127.5) - 1.0
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(DEVICE)

        # [CRITICAL FIX] Logic change: Check if pixel > 0, NOT > 127
        # This converts your class indices (1, 2, 3...) into a binary 1.0
        mask_tensor = (torch.tensor(mask_resized).float() > 0).float()
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(DEVICE)

        # 4. Inference
        with torch.no_grad():
            output = model(img_tensor, mask_tensor)

        # 5. Post-process: [-1, 1] -> uint8 RGB, restored to original size.
        output_np = output.squeeze().permute(1, 2, 0).cpu().numpy()
        output_np = (output_np + 1.0) * 127.5
        output_np = np.clip(output_np, 0, 255).astype(np.uint8)
        output_final = cv2.resize(output_np, (img_w, img_h))

        return jsonify({'result': f"data:image/png;base64,{to_base64(output_final)}"})
    except Exception as e:
        # Top-level route boundary: surface the failure to the client.
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)