File size: 4,055 Bytes
e1aa346
 
 
 
 
 
 
 
 
b172d7f
e1aa346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2aac227
b172d7f
 
 
 
 
 
e1aa346
b172d7f
 
e1aa346
 
 
1dc4910
 
 
 
 
 
e1aa346
7397034
e1aa346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb1158f
 
e1aa346
 
cb1158f
 
 
 
 
e1aa346
 
 
 
 
 
7397034
e1aa346
 
 
 
 
 
 
7397034
e1aa346
 
7397034
 
 
 
 
 
 
e1aa346
6b4af39
e1aa346
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import io
import torch
from torch import nn 
from PIL import Image
import torchvision.utils as vutils
from fastapi import FastAPI, Response, HTTPException, Query
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download, login

from models import Generator

app = FastAPI()

# CORS Configuration: accept requests from any origin, with any method and
# any header, so browser front-ends hosted elsewhere can call the API.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Configuration Constants
Z_DIM = 100                              # length of the latent noise vector fed to Generator
DEVICE = torch.device("cpu")             # inference device; the checkpoint is mapped here too
REPO_ID = "SaniaE/GeoGen"                # Hugging Face Hub repo holding the checkpoint
FILENAME = "dcgans_model_checkpoint.pt"  # checkpoint file name inside REPO_ID
gen_model = None                         # set by load_model() at startup; None until then

@app.on_event("startup")
def load_model():
    """Download the generator checkpoint from the Hugging Face Hub and load it.

    Runs once at application startup. On success the module-level ``gen_model``
    holds the generator in eval mode. On any failure the error is printed and
    ``gen_model`` is left untouched (``None`` at startup), so the image
    endpoints answer 503 instead of serving an untrained network.
    """
    global gen_model
    try:
        # Environment variable names are case-sensitive on POSIX. The log
        # message below refers to HF_TOKEN, so accept that spelling as well
        # as the legacy "HF_Token" one already used by existing deployments.
        token = os.getenv("HF_Token") or os.getenv("HF_TOKEN")
        if token:
            login(token=token)
            print("Login successful.")
        else:
            print("No HF_TOKEN found - attempting public download.")

        model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
        print(f"File downloaded to: {model_path}")

        # NOTE(review): torch.load unpickles the file, and unpickling can execute
        # arbitrary code. The checkpoint comes from a remote Hub repo; consider
        # weights_only=True once it is confirmed the checkpoint holds only
        # tensors and plain containers.
        checkpoint = torch.load(model_path, map_location=DEVICE)

        # Build into a local and publish only after the weights are in place.
        # Otherwise a failure here (e.g. a missing "gen_state_dict" key) would
        # leave a randomly initialised model visible to the endpoints.
        model = Generator(z_dim=Z_DIM).to(DEVICE)
        # strict=False tolerates key mismatches; they are printed for diagnosis.
        missing, unexpected = model.load_state_dict(
            checkpoint["gen_state_dict"], strict=False
        )

        print("Unexpected keys: ", unexpected)
        print("Missing keys: ", missing)
        model.eval()
        gen_model = model
        print("SUCCESS: Petrol Pump GAN is live!")
    except Exception as e:
        # Best-effort startup: keep the server up; endpoints report 503.
        print(f"Error loading model: {e}")


def postprocess_image(tensor):
    """Convert generator output in the tanh range [-1, 1] into a PIL image.

    The input is tiled with ``make_grid`` (no padding). Presumably it is a
    batch shaped (N, C, H, W) coming from the generator — TODO confirm
    against ``Generator``.

    Returns:
        PIL.Image.Image built from an (H, W, C) uint8 array.
    """
    # Map [-1, 1] onto [0, 1] and clip any numerical overshoot.
    unit = ((tensor + 1) / 2).clamp(0, 1)

    # Tile the batch into one image, with no gap between tiles.
    tiled = vutils.make_grid(unit, padding=0, normalize=False)

    # Scale to 0..255 (adding 0.5 rounds before the uint8 truncation), move
    # channels last for PIL, and convert to a CPU uint8 numpy array.
    pixels = (
        tiled.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )
    return Image.fromarray(pixels)


def get_image_stream(tensor):
    """Helper to convert tensor to a streaming-ready PNG.

    The tensor-to-image conversion is delegated to ``postprocess_image``, so
    both helpers share one unnormalisation path instead of two copies that
    could drift apart.

    Args:
        tensor: generator output in the [-1, 1] range (see postprocess_image).

    Returns:
        io.BytesIO containing PNG bytes, rewound to offset 0 so it can be
        handed directly to a StreamingResponse.
    """
    buf = io.BytesIO()
    postprocess_image(tensor).save(buf, format="PNG")
    buf.seek(0)  # rewind so the reader starts at the PNG header
    return buf


@app.get("/")
def read_root():
    """Health check: report that the service is up and which Hub repo it uses."""
    payload = dict(status="online", model=REPO_ID)
    return payload


@app.get("/generate")
def generate_random(seed: int = Query(None)):
    """Endpoint 1: Fixed context generation for a session.

    Draws one latent vector of length ``Z_DIM`` and returns the generated
    image as a PNG.

    Args:
        seed: optional session seed. The same seed yields the same image. When
            omitted, a non-deterministic seed is chosen.

    Returns:
        StreamingResponse with ``image/png`` content. The seed actually used
        is echoed in the ``X-Seed`` header so a client that omitted ``seed``
        can replay the image. Browsers only see it if the CORS config exposes
        the header.

    Raises:
        HTTPException(503): the model failed to load at startup.
    """
    if gen_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # A per-request generator keeps the output reproducible for a given seed
    # even when FastAPI runs requests concurrently in its thread pool. Seeding
    # the process-wide RNG (torch.manual_seed) lets concurrent requests clobber
    # each other's state. On CPU, the draws match torch.manual_seed + randn.
    rng = torch.Generator(device=DEVICE)
    if seed is None:
        active_seed = rng.seed()  # seeds rng non-deterministically, returns the seed
    else:
        active_seed = seed
        rng.manual_seed(active_seed)

    with torch.inference_mode():
        noise = torch.randn(1, Z_DIM, generator=rng, device=DEVICE)
        fake_img = gen_model(noise)
        image_stream = get_image_stream(fake_img)

    return StreamingResponse(
        image_stream,
        media_type="image/png",
        headers={"X-Seed": str(active_seed)},
    )


@app.get("/explore")
def explore_latent(seed: int, x_shift: float = Query(0.0, ge=-5.0, le=5.0), y_shift: float = Query(0.0, ge=-5.0, le=5.0)):
    """Endpoint 2: Controlled generation for 'Tuning'.

    Starts from the latent vector determined by ``seed`` and nudges it:
    ``x_shift`` is added to latent dims 0-9, ``y_shift`` to dims 10-19, and a
    seed-determined random direction is mixed in with weight
    ``0.3 * (|x_shift| + |y_shift|)``. With both shifts at 0 the result equals
    the unshifted latent for ``seed``.

    Args:
        seed: seed selecting the base latent vector and the random direction.
        x_shift: offset in [-5, 5] for latent dims 0-9.
        y_shift: offset in [-5, 5] for latent dims 10-19.

    Returns:
        StreamingResponse with ``image/png`` content.

    Raises:
        HTTPException(503): the model failed to load at startup.
        HTTPException(500): generation failed (e.g. a seed torch rejects).
    """
    if gen_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Per-request generator: avoids races on the process-wide RNG between
        # concurrent requests. The draw order (base noise, then direction) is
        # unchanged, so outputs match the previous global-seed behavior on CPU.
        rng = torch.Generator(device=DEVICE)
        rng.manual_seed(seed)

        with torch.inference_mode():
            noise = torch.randn(1, Z_DIM, generator=rng, device=DEVICE)

            # Structured control
            noise[:, :10] += x_shift
            noise[:, 10:20] += y_shift

            # Random direction, scaled by how far the user moved the controls
            direction = torch.randn(noise.shape, generator=rng, device=DEVICE)
            noise = noise + 0.3 * direction * (abs(x_shift) + abs(y_shift))

            fake_img = gen_model(noise)
            image_stream = get_image_stream(fake_img)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return StreamingResponse(image_stream, media_type="image/png")