# NOTE: "Spaces: Sleeping" below was Hugging Face Spaces page chrome captured
# along with the source; it is not part of the application code.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import numpy as np | |
| import cv2 | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
# Flask application and runtime configuration.
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the browser front-end
DEVICE = torch.device('cpu')  # Force CPU for HF Free Tier
MODEL_PATH = "best_model_fixed.pth"  # Upload your trained .pth file
class InpaintingGenerator(nn.Module):
    """U-Net-style inpainting generator with a ResNet-34 encoder backbone.

    Input is a 3-channel image concatenated with a 1-channel mask
    (4 channels total); output is a 3-channel image squashed to [-1, 1]
    by the final Tanh. Attribute names (enc1..enc5, bottleneck, up1..up4,
    texture_refine, final) must not be renamed: the checkpoint's
    state-dict keys are derived from them.
    """

    def __init__(self, input_channels: int = 4):
        super().__init__()
        # Untrained backbone; real weights come from the checkpoint load below.
        resnet = models.resnet34(weights=None)
        # First conv replaced to accept image+mask channels. NOTE: resnet's
        # maxpool is intentionally omitted, so the encoder downsamples by
        # x16 overall — matching the four x2 decoder upsampling stages.
        self.enc1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            resnet.bn1, resnet.relu
        )
        self.enc2 = resnet.layer1  # 64ch,  stride /2 overall
        self.enc3 = resnet.layer2  # 128ch, /4
        self.enc4 = resnet.layer3  # 256ch, /8
        self.enc5 = resnet.layer4  # 512ch, /16
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(True)
        )
        # Decoder: each block upsamples x2. Input channels for up2..up4
        # include the skip connection concatenated in forward().
        self.up1 = self._make_decoder_block(512, 256)
        self.up2 = self._make_decoder_block(512, 128)  # 256+256
        self.up3 = self._make_decoder_block(256, 64)   # 128+128
        self.up4 = self._make_decoder_block(128, 32)   # 64+64
        self.texture_refine = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(True)
        )
        self.final = nn.Sequential(
            nn.Conv2d(32, 16, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(16, 3, 3, padding=1), nn.Tanh()  # output in [-1, 1]
        )

    def _make_decoder_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Build one x2-upsampling decoder stage (transposed conv + 2 convs)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(True)
        )

    def forward(self, img, mask):
        """Inpaint `img` guided by `mask`.

        img: (B, 3, H, W); mask: (B, 1, H, W). H and W must be divisible
        by 16 for the skip-connection shapes to line up (callers resize
        to 512x512). Returns (B, 3, H, W) in [-1, 1].
        """
        x = torch.cat([img, mask], dim=1)  # -> 4 channels, matching enc1
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        x5 = self.enc5(x4)
        x = self.bottleneck(x5)
        # Decode with U-Net skip connections from the matching encoder depth.
        x = self.up1(x)
        x = torch.cat([x, x4], dim=1)
        x = self.up2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.up3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up4(x)
        x = self.texture_refine(x)
        return self.final(x)
# Instantiate the generator and load the trained checkpoint at import time,
# so the Flask handlers below can use `model` directly.
print("Loading Inpainting Model...")
model = InpaintingGenerator().to(DEVICE)
try:
    # Set weights_only=False to avoid numpy errors
    # NOTE(review): weights_only=False unpickles arbitrary objects; only use
    # with a checkpoint file you trust (this one ships with the Space).
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    # Handle DataParallel wrapping
    if 'generator' in checkpoint:
        # Checkpoint saved as a dict of sub-models; take the generator's weights.
        state_dict = checkpoint['generator']
    else:
        state_dict = checkpoint
    # Strip the 'module.' prefix nn.DataParallel adds so keys match this model.
    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    # strict=False tolerates missing/unexpected keys instead of raising.
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()  # inference mode: fixes BatchNorm statistics
    print("Model loaded successfully!")
except Exception as e:
    # Best-effort: keep the server up even if the checkpoint is absent,
    # so the health-check endpoint still responds.
    print(f"Error loading model: {e}")
def to_base64(image_array):
    """Encode an HxWx3 uint8 array as a base64 PNG payload (no data-URI prefix)."""
    png_buffer = BytesIO()
    Image.fromarray(image_array).save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode('utf-8')
@app.route('/')
def home():
    """Health-check endpoint confirming the API process is alive.

    Fix: the route decorator was missing (apparently stripped during page
    extraction), leaving this handler unregistered and unreachable.
    """
    return "Inpainting API is Running!"
@app.route('/inpaint', methods=['POST'])
def inpaint():
    """Inpaint the masked region of an uploaded image.

    Expects multipart/form-data with two file fields:
      - 'image': the RGB image to inpaint
      - 'mask' : single-channel mask where any pixel > 0 marks a region to
        fill (class indices 1, 2, 3... all collapse to 1.0)

    Returns JSON {'result': 'data:image/png;base64,...'} at the original
    image resolution, or {'error': ...} with HTTP 400/500.

    Fixes vs. original: the route decorator was missing (handler was never
    registered); cv2.imdecode failures now return a clean 400 instead of
    raising; 4-channel (BGRA) masks use the alpha-aware grayscale
    conversion (plain BGR2GRAY raises on 4-channel input).
    """
    if 'image' not in request.files or 'mask' not in request.files:
        return jsonify({'error': 'Please upload both image and mask'}), 400
    try:
        # 1. Read Image (decoded BGR, converted to RGB for the model)
        img_file = request.files['image']
        img_arr = np.frombuffer(img_file.read(), np.uint8)
        img_cv = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
        if img_cv is None:  # corrupt / unsupported format
            return jsonify({'error': 'Could not decode image file'}), 400
        img_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

        # 2. Read Mask
        mask_file = request.files['mask']
        mask_arr = np.frombuffer(mask_file.read(), np.uint8)
        # Read "unchanged" to preserve the low class values (1, 2, 3)
        mask_cv = cv2.imdecode(mask_arr, cv2.IMREAD_UNCHANGED)
        if mask_cv is None:
            return jsonify({'error': 'Could not decode mask file'}), 400
        # If mask is RGB/RGBA, convert to grayscale
        if len(mask_cv.shape) > 2:
            code = cv2.COLOR_BGRA2GRAY if mask_cv.shape[2] == 4 else cv2.COLOR_BGR2GRAY
            mask_cv = cv2.cvtColor(mask_cv, code)

        # 3. Preprocess: model runs at a fixed 512x512 resolution
        img_h, img_w = img_cv.shape[:2]
        img_resized = cv2.resize(img_cv, (512, 512))
        # Nearest neighbour preserves exact class IDs (0, 1, 2...)
        mask_resized = cv2.resize(mask_cv, (512, 512), interpolation=cv2.INTER_NEAREST)
        # Normalize image to [-1, 1], matching the generator's Tanh range
        img_tensor = (torch.tensor(img_resized).float() / 127.5) - 1.0
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        # Binarize on > 0 (NOT > 127) so class indices 1, 2, 3... become 1.0
        mask_tensor = (torch.tensor(mask_resized).float() > 0).float().unsqueeze(0).unsqueeze(0).to(DEVICE)

        # 4. Inference
        with torch.no_grad():
            output = model(img_tensor, mask_tensor)

        # 5. Post-process: [-1, 1] -> [0, 255] uint8, restore original size
        output_np = output.squeeze().permute(1, 2, 0).cpu().numpy()
        output_np = (output_np + 1.0) * 127.5
        output_np = np.clip(output_np, 0, 255).astype(np.uint8)
        output_final = cv2.resize(output_np, (img_w, img_h))
        return jsonify({'result': f"data:image/png;base64,{to_base64(output_final)}"})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # 7860 is the port Hugging Face Spaces exposes by default;
    # 0.0.0.0 binds all interfaces so the container proxy can reach it.
    app.run(host='0.0.0.0', port=7860)