Spaces:
Running
Running
| from flask import Flask, jsonify, request, send_file | |
| from flask_cors import CORS | |
| from PIL import Image | |
| import io | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| import torch.nn.functional as F | |
| import os | |
| import time | |
| from collections import OrderedDict | |
app = Flask(__name__)
# NOTE(review): origins='*' lets any web page call this API — confirm this
# wide-open CORS policy is intended before deploying beyond a demo.
CORS(app, origins='*')
# Dimensionality of the VAE latent vector; must match the checkpoint this
# file loads further down (vae_ddp.pth).
latent_dim = 50
@app.route("/health")
def health():
    """Liveness probe: return a plain 'OK' body with HTTP status 200."""
    # NOTE(review): this handler had no @app.route decorator in the original
    # file (likely lost in extraction), so it was never reachable; it is
    # registered here at /health.
    return "OK", 200
@app.route("/users")
def users():
    """Return a static JSON object with a hard-coded list of user names."""
    # NOTE(review): this handler had no @app.route decorator in the original
    # file (likely lost in extraction), so it was never reachable; it is
    # registered here at /users.
    return jsonify({
        "users": [
            "kiran",
            "kumar",
            "kanathala",
        ]
    })
# Run inference on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Preprocessing pipeline: the VAE below expects 3x64x64 float tensors in [0, 1].
transform = transforms.Compose([
    transforms.Resize(64),  # scale so the shorter side is 64 px
    transforms.CenterCrop(64),  # crop the central 64x64 square
    transforms.ToTensor()  # PIL image -> float tensor scaled to [0, 1]
])
class Encoder(nn.Module):
    """Convolutional encoder: maps a 3x64x64 image batch to the mean and
    log-variance of a diagonal Gaussian over the latent space."""

    def __init__(self, latent_dim):
        super().__init__()
        # Three stride-2 convolutions halve the spatial size each time
        # (64 -> 32 -> 16 -> 8) while growing channels to 64, then flatten.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

    def forward(self, x):
        """Return (mu, logvar), each of shape (batch, latent_dim)."""
        features = F.relu(self.fc1(self.conv_layers(x)))
        return self.fc_mu(features), self.fc_logvar(features)
class Decoder(nn.Module):
    """Decoder: maps a latent vector batch back to 3x64x64 images in [0, 1]."""

    def __init__(self, latent_dim):
        super().__init__()
        # Project the latent vector up to a 64x8x8 feature map.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (64, 8, 8)),
        )
        # Mirror of the encoder: three stride-2 transposed convolutions
        # upsample 8 -> 16 -> 32 -> 64; Sigmoid keeps pixel values in [0, 1].
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode a latent batch z into an image batch."""
        return self.deconv(self.fc(z))
class VAE(nn.Module):
    """Variational autoencoder pairing the Encoder and Decoder above."""

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, sigma^2) via the reparameterization trick so the
        sampling step stays differentiable."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for an input image batch x."""
        mu, logvar = self.encoder(x)
        return self.decoder(self.reparameterize(mu, logvar)), mu, logvar
# Build the VAE and load weights that were trained under
# DistributedDataParallel (DDP), straight at import time.
model = VAE(latent_dim).to(device)
# SECURITY NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from a trusted source; consider weights_only=True on newer
# torch versions.
state_dict = torch.load("vae_ddp.pth", map_location=device)
# DDP prefixes every parameter name with "module."; strip that prefix so
# the keys match this non-wrapped model.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    new_state_dict[k.replace("module.", "")] = v
model.load_state_dict(new_state_dict)
model.eval()  # inference mode: disables dropout / batch-norm updates
print("Model loaded successfully!")
@app.route("/reconstruct", methods=["POST"])
def reconstruct():
    """Reconstruct an uploaded image through the VAE.

    Expects a multipart/form-data POST with the file under the "image" key.
    Returns the reconstruction as a PNG resized back to the upload's
    original dimensions, or a 400 error for a missing/unreadable image.
    """
    # NOTE(review): this handler had no @app.route decorator in the original
    # file (likely lost in extraction), so it was never reachable; it is
    # registered here at POST /reconstruct.
    if "image" not in request.files:
        return "No image uploaded", 400
    file = request.files["image"]
    try:
        img = Image.open(file.stream).convert("RGB")
    except Exception:
        # Pillow raises UnidentifiedImageError (among others) on corrupt or
        # non-image uploads; report a client error instead of a 500.
        return "Invalid image file", 400
    orig_size = img.size  # (width, height) of the original upload

    # Preprocess to the model's 3x64x64 input and run the VAE without grads.
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        decoded, mu, logvar = model(img_tensor)

    # Drop the batch dimension, convert back to PIL, and restore the
    # caller's original resolution.
    recon_img = transforms.ToPILImage()(decoded.squeeze(0).cpu())
    recon_img = recon_img.resize(orig_size)

    # Stream the PNG back from memory without touching the filesystem.
    buf = io.BytesIO()
    recon_img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png", as_attachment=False, download_name="reconstructed.png")
# Start Flask's development server when run directly (use a production WSGI
# server such as gunicorn when deploying).
if __name__ == "__main__":
    app.run()