import os
import io
import torch
from torch import nn
from PIL import Image
import torchvision.utils as vutils
from fastapi import FastAPI, Response, HTTPException, Query
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download, login
from models import Generator
app = FastAPI()

# CORS configuration: allow any origin/method/header so a browser front-end
# hosted on a different domain can call this API directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration constants
Z_DIM = 100                                # size of the generator's latent (noise) vector
DEVICE = torch.device("cpu")               # all inference runs on CPU
REPO_ID = "SaniaE/GeoGen"                  # Hugging Face Hub repo holding the checkpoint
FILENAME = "dcgans_model_checkpoint.pt"    # checkpoint file name inside the repo

# Populated by load_model() at startup; while None, endpoints answer 503.
gen_model = None
@app.on_event("startup")
def load_model():
    """Download the GAN checkpoint from the HF Hub and initialise the generator.

    Runs once at application startup. Any failure is logged and swallowed so
    the API process still boots; in that case ``gen_model`` stays ``None`` and
    the endpoints respond with 503.
    """
    global gen_model
    try:
        # Env vars are case-sensitive; the log message below refers to
        # HF_TOKEN, so accept that spelling as well as the original HF_Token.
        token = os.getenv("HF_Token") or os.getenv("HF_TOKEN")
        if token:
            login(token=token)
            print("Login successful.")
        else:
            print("No HF_TOKEN found - attempting public download.")
        model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
        print(f"File downloaded to: {model_path}")
        # Checkpoint is a dict containing (at least) the generator state dict.
        checkpoint = torch.load(model_path, map_location=DEVICE)
        gen_model = Generator(z_dim=Z_DIM).to(DEVICE)
        # strict=False tolerates minor architecture drift; print what was
        # skipped so mismatches are visible in the startup log.
        missing, unexpected = gen_model.load_state_dict(
            checkpoint["gen_state_dict"], strict=False
        )
        print("Unexpected keys: ", unexpected)
        print("Missing keys: ", missing)
        gen_model.eval()  # inference mode: freeze dropout / batch-norm stats
        print("SUCCESS: Petrol Pump GAN is live!")
    except Exception as e:
        # Deliberate best-effort: log and keep serving; endpoints return 503.
        print(f"Error loading model: {e}")
def postprocess_image(tensor):
    """Convert a tanh-scaled image tensor ([-1, 1]) into a PIL image."""
    # Undo the tanh normalisation, then clip into the valid [0, 1] range.
    scaled = ((tensor + 1) / 2).clamp(0, 1)
    grid = vutils.make_grid(scaled, padding=0, normalize=False)
    # Quantise to uint8 in HWC layout; the +0.5 rounds before truncation.
    pixels = (
        grid.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )
    return Image.fromarray(pixels)
def get_image_stream(tensor):
    """Convert a generator output tensor into an in-memory PNG stream.

    Delegates the tensor -> PIL conversion to postprocess_image() instead of
    duplicating it line-for-line (the original copy only omitted the explicit
    normalize=False, which is make_grid's default anyway), then serialises the
    image into a BytesIO ready for StreamingResponse.
    """
    pil_img = postprocess_image(tensor)
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    buf.seek(0)  # rewind so the response reads from the start
    return buf
@app.get("/")
def read_root():
    """Health-check endpoint: report service status and the serving model."""
    payload = {"status": "online", "model": REPO_ID}
    return payload
@app.get("/generate")
def generate_random(seed: int = Query(None)):
    """Endpoint 1: Fixed context generation for a session.

    A client may pin ``seed`` so repeated calls within a session reproduce the
    same image; without it a fresh random seed is drawn per request.
    """
    if gen_model is None:
        # Startup failed or is still in progress - nothing to sample from.
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Use the provided session seed or fall back to a fresh random one.
    active_seed = seed if seed is not None else torch.seed()
    torch.manual_seed(active_seed)
    try:
        with torch.inference_mode():
            noise = torch.randn(1, Z_DIM, device=DEVICE)
            fake_img = gen_model(noise)
        return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
    except Exception as e:
        # Mirror /explore's error handling so clients get a JSON 500 body
        # instead of an unhandled server error.
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/explore")
def explore_latent(seed: int, x_shift: float = Query(0.0, ge=-5.0, le=5.0), y_shift: float = Query(0.0, ge=-5.0, le=5.0)):
    """Endpoint 2: Controlled generation for 'Tuning'.

    Starting from the deterministic latent vector for ``seed``, the first 10
    latent dims are shifted by ``x_shift`` and the next 10 by ``y_shift``;
    a seeded random direction is then blended in, scaled by the total shift
    magnitude, so sliders at zero reproduce the plain /generate image.
    """
    if gen_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        with torch.inference_mode():
            # Re-seed so the same (seed, x_shift, y_shift) always yields the
            # same image. The CUDA call is a no-op on this CPU-only device.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
            noise = torch.randn(1, Z_DIM, device=DEVICE)
            # Structured control: dedicated latent dims per slider.
            noise[:, :10] += x_shift
            noise[:, 10:20] += y_shift
            # Random (but seeded, hence reproducible) exploration direction.
            direction = torch.randn_like(noise)
            noise = noise + 0.3 * direction * (abs(x_shift) + abs(y_shift))
            # NOTE: removed leftover debug print of the noise vector.
            fake_img = gen_model(noise)
        return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))