File size: 3,003 Bytes
cd698c9
 
 
 
 
 
 
 
 
 
 
999e0c3
cd698c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS
from PIL import Image
import io
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import os
import time

app = Flask(__name__)
# Restrict cross-origin access to the frontend's origin. The configured value
# is compared with the browser's Origin header, which is serialized as
# scheme://host[:port] and never carries a trailing slash, so the entry must
# not end in "/" or it will never match.
CORS(app, origins=["https://autoencoderkiranfrontend.vercel.app"])

@app.route("/healthz")
def health():
    """Health-check endpoint: always answers with the body "OK" and HTTP 200."""
    body = "OK"
    status_code = 200
    return body, status_code


@app.route("/api/users", methods=['GET'])
def users():
    """Return a fixed list of user names as JSON under the "users" key."""
    names = ['kiran', 'kumar', 'kanathala']
    return jsonify(users=names)

# ----------------------------
# Device selection
# ----------------------------
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# ----------------------------
# Preprocessing transform (per the original note: same as training)
# ----------------------------
# Resize every input to 128x128, then convert the PIL image to a tensor.
transform = transforms.Compose(
    [transforms.Resize(size=(128, 128)), transforms.ToTensor()]
)

class Autoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    The encoder stacks four stride-2 3x3 convolutions
    (3 -> 32 -> 64 -> 128 -> 256 channels) with a ReLU between consecutive
    convolutions. The decoder mirrors it with four stride-2 transposed
    convolutions (256 -> 128 -> 64 -> 32 -> 3) and ends in a sigmoid, so
    outputs lie in [0, 1]. A 128x128 input is reduced to a 256x8x8
    bottleneck and restored to 128x128.

    The position of each layer inside its ``nn.Sequential`` determines the
    state-dict key names, so layers must not be reordered if previously
    saved weights are to keep loading.
    """

    def __init__(self):
        super().__init__()

        # Hyper-parameters shared by every down-/up-sampling layer.
        down = dict(kernel_size=3, stride=2, padding=1)
        up = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        encoder_layers = [
            nn.Conv2d(3, 32, **down),
            nn.ReLU(),
            nn.Conv2d(32, 64, **down),
            nn.ReLU(),
            nn.Conv2d(64, 128, **down),
            nn.ReLU(),
            # No activation after the bottleneck convolution.
            nn.Conv2d(128, 256, **down),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(256, 128, **up),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, **up),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, **up),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, **up),
            # Squash the reconstruction into [0, 1].
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back to image space."""
        return self.decoder(self.encoder(x))
# ----------------------------
# Load trained model
# ----------------------------
# Build the network on the selected device, then restore the saved weights.
# The checkpoint path is relative, so it is resolved against the process's
# current working directory.
model = Autoencoder()
model.to(device)  # nn.Module.to moves the parameters in place
model.load_state_dict(
    torch.load("celeba_autoencoder.pth", map_location=device)
)
# Switch to inference mode.
model.eval()
print("Model loaded successfully!")

@app.route("/reconstruct", methods=["POST"])
def reconstruct():
    """Run an uploaded image through the autoencoder and return the result.

    Expects a multipart form upload with the file in the ``image`` field.
    The image is converted to RGB, passed through the module-level
    ``transform`` and ``model``, resized back to the upload's original
    dimensions and sent back as a PNG.

    Returns:
        An ``image/png`` response on success, or a short message with
        HTTP 400 when no file was sent or the upload cannot be decoded as
        an image.
    """
    if "image" not in request.files:
        return "No image uploaded", 400

    upload = request.files["image"]
    try:
        img = Image.open(upload.stream).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # Pillow signals unreadable, truncated or oversized images with these
        # exceptions. A bad upload is the client's error, not a server fault,
        # so answer 400 instead of letting it surface as a 500.
        return "Invalid image file", 400
    orig_size = img.size

    # Transform and send through the autoencoder (no gradients needed).
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        reconstructed = model(img_tensor)

    # Convert back to PIL: drop the batch dimension and move to the CPU first.
    recon_img = reconstructed.squeeze(0).cpu()
    recon_img = transforms.ToPILImage()(recon_img)
    # The model works on the transform's fixed size; restore the upload's size.
    recon_img = recon_img.resize(orig_size)

    # Send the image from an in-memory buffer without saving to disk.
    buf = io.BytesIO()
    recon_img.save(buf, format="PNG")
    buf.seek(0)  # rewind so send_file reads from the start

    return send_file(buf, mimetype="image/png", as_attachment=False, download_name="reconstructed.png")

# Start Flask's built-in server with its default settings, but only when this
# file is executed directly as a script (not when it is imported as a module).
if __name__ == "__main__":
    app.run()