from flask import Flask, jsonify, request, send_file
from flask_cors import CORS
from PIL import Image
import io
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from collections import OrderedDict

app = Flask(__name__)
CORS(app, origins='*')

latent_dim = 50


@app.route("/healthz")
def health():
    return "OK", 200


@app.route("/api/users", methods=['GET'])
def users():
    return jsonify({
        "users": [
            'kiran',
            'kumar',
            'kanathala',
        ]
    })


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Preprocessing: resize the shorter side to 64 pixels, then center-crop,
# so every input reaches the model as a 64x64 RGB tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize(64),       # Resize shorter side to 64 pixels
    transforms.CenterCrop(64),   # Crop the center 64x64 square
    transforms.ToTensor()        # Convert to a float tensor in [0, 1]
])


class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        # Three stride-2 convolutions halve the resolution each time:
        # 64x64 -> 32x32 -> 16x16 -> 8x8.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

    def forward(self, x):
        h = self.conv_layers(x)
        h = F.relu(self.fc1(h))
        return self.fc_mu(h), self.fc_logvar(h)


class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        # Project the latent vector back up to a 64x8x8 feature map.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (64, 8, 8))
        )
        # Mirror the encoder with transposed convolutions:
        # 8x8 -> 16x16 -> 32x32 -> 64x64.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  # Pixel values constrained to [0, 1]
        )

    def forward(self, z):
        h = self.fc(z)
        return self.deconv(h)


class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar


model = VAE(latent_dim).to(device)

# The checkpoint was saved from a DistributedDataParallel-wrapped model,
# so every key carries a "module." prefix that must be stripped before
# loading into the bare VAE.
state_dict = torch.load("vae_ddp.pth", map_location=device)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    new_state_dict[k.replace("module.", "")] = v
model.load_state_dict(new_state_dict)
model.eval()
print("Model loaded successfully!")


@app.route("/reconstruct", methods=["POST"])
def reconstruct():
    if "image" not in request.files:
        return "No image uploaded", 400

    file = request.files["image"]
    img = Image.open(file.stream).convert("RGB")
    orig_size = img.size  # (width, height), used to restore the output size

    # Transform and send through the autoencoder
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        decoded, mu, logvar = model(img_tensor)

    # Convert back to PIL
    recon_img = decoded.squeeze(0).cpu()  # remove batch dimension
    recon_img = transforms.ToPILImage()(recon_img)
    recon_img = recon_img.resize(orig_size)

    # Send image as BytesIO
    buf = io.BytesIO()
    recon_img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png", as_attachment=False,
                     download_name="reconstructed.png")


if __name__ == "__main__":
    app.run()
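
# ---------------------------------------------------------------------------
# Example usage — a minimal sketch, not part of the server itself. It assumes
# the app is running locally on Flask's default port 5000 and that a file
# named "photo.png" exists; both names are illustrative. The form field
# "image" and the PNG response match the /reconstruct handler above.
#
#   curl -X POST -F "image=@photo.png" \
#        http://localhost:5000/reconstruct -o reconstructed.png
#
# Or from Python, using the `requests` library:
#
#   import requests
#   with open("photo.png", "rb") as f:
#       resp = requests.post("http://localhost:5000/reconstruct",
#                            files={"image": f})
#   with open("reconstructed.png", "wb") as out:
#       out.write(resp.content)
# ---------------------------------------------------------------------------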