# VAEbackend/server.py
# Flask inference server for a convolutional VAE (initial Hugging Face push).
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS
from PIL import Image
import io
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import os
import time
from collections import OrderedDict
# Flask application and CORS setup.
app = Flask(__name__)
# NOTE(review): origins='*' lets any site call this API — acceptable for a
# demo backend, but tighten before serving anything sensitive.
CORS(app, origins='*')
# Dimensionality of the VAE latent vector; must match the trained checkpoint.
latent_dim = 50
@app.route("/healthz")
def health():
    """Liveness probe: always report the service as up."""
    body, status = "OK", 200
    return body, status
@app.route("/api/users", methods=['GET'])
def users():
    """Return the static demo user list as JSON: {"users": [...]}."""
    names = ['kiran', 'kumar', 'kanathala']
    return jsonify({"users": names})
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Preprocessing pipeline: every uploaded image becomes a 3x64x64 float
# tensor with values in [0, 1] — the input the VAE was trained on.
transform = transforms.Compose([
    transforms.Resize(64),      # scale so the shorter side is 64 px
    transforms.CenterCrop(64),  # crop the center 64x64 square
    transforms.ToTensor()       # PIL image -> CxHxW tensor in [0, 1]
])
class Encoder(nn.Module):
    """Convolutional encoder: 3x64x64 image -> latent mean and log-variance.

    Attribute names (conv_layers, fc1, fc_mu, fc_logvar) and the Sequential
    layer ordering are load-bearing — they must match the checkpoint's keys.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # 64x64 -> 32x32 -> 16x16 -> 8x8, widening 3 -> 16 -> 32 -> 64 channels.
        stack = [
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.conv_layers = nn.Sequential(*stack)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

    def forward(self, x):
        """Return (mu, logvar), each of shape (batch, latent_dim)."""
        features = self.conv_layers(x)
        hidden = F.relu(self.fc1(features))
        return self.fc_mu(hidden), self.fc_logvar(hidden)
class Decoder(nn.Module):
    """Decoder: latent vector -> 3x64x64 image with values in [0, 1].

    Attribute names (fc, deconv) and the Sequential layer ordering are
    load-bearing — they must match the checkpoint's keys.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Project the latent vector up to a 64x8x8 feature map.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (64, 8, 8)),
        )
        # Mirror of the encoder: 8x8 -> 16x16 -> 32x32 -> 64x64 spatially.
        upsample = [
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # squash pixel values into [0, 1]
        ]
        self.deconv = nn.Sequential(*upsample)

    def forward(self, z):
        """Return a (batch, 3, 64, 64) reconstruction for latent batch z."""
        return self.deconv(self.fc(z))
class VAE(nn.Module):
    """Variational autoencoder combining the Encoder and Decoder above."""

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for an input image batch x."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar
# --- Load trained VAE weights -------------------------------------------------
model = VAE(latent_dim).to(device)
# NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints
# from trusted sources (consider weights_only=True on torch >= 1.13).
state_dict = torch.load("vae_ddp.pth", map_location=device)
# Checkpoints saved from DistributedDataParallel prefix every key with
# "module."; strip that prefix — and only the leading prefix, unlike
# str.replace, which would also mangle "module." appearing mid-key.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    key = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[key] = v
model.load_state_dict(new_state_dict)
model.eval()  # inference mode: freezes dropout / batch-norm behavior
print("Model loaded successfully!")
@app.route("/reconstruct", methods=["POST"])
def reconstruct():
    """Reconstruct an uploaded image through the VAE.

    Expects a multipart/form-data POST with the file under the "image" key.
    Returns the reconstruction as a PNG resized back to the upload's original
    dimensions, or a 400 for a missing or unreadable file.
    """
    if "image" not in request.files:
        return "No image uploaded", 400
    file = request.files["image"]
    try:
        img = Image.open(file.stream).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # corrupt or non-image uploads; previously this crashed with a 500.
        return "Invalid image file", 400
    orig_size = img.size  # (width, height) of the original upload
    # Preprocess to a 1x3x64x64 batch and run it through the VAE.
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        decoded, mu, logvar = model(img_tensor)
    # Convert the reconstruction back to a PIL image at the original size.
    recon_img = transforms.ToPILImage()(decoded.squeeze(0).cpu())
    recon_img = recon_img.resize(orig_size)
    # Stream the PNG back without touching the filesystem.
    buf = io.BytesIO()
    recon_img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png", as_attachment=False, download_name="reconstructed.png")
if __name__ == "__main__":
    # Development entry point; use a production WSGI server (e.g. gunicorn)
    # when deploying.
    app.run()