# sar-colorization / main.py — SAR Image Colorization, U-Net + FastAPI
# Author: saann (commit 4bb6657)
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.requests import Request
from fastapi.responses import JSONResponse
import numpy as np
from PIL import Image
import io
import os
import base64
import tensorflow as tf
from tensorflow.keras import layers, Model
# FastAPI application serving both the HTML UI and the colorization API.
app = FastAPI(title="SAR Image Colorization")
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# Model input/output resolution as (width, height) for PIL resize.
IMG_SIZE = (256, 256)
# Keras weights file; if missing at startup the app runs in demo mode.
MODEL_PATH = "sar_model.weights.h5"
# Populated by the startup handler; None means demo mode (no trained weights).
model = None
def build_unet(input_shape=(256, 256, 1)):
    """Build a 4-level U-Net mapping a 1-channel SAR image to 3-channel RGB.

    Encoder filter widths are 32/64/128/256 with a 512-filter bottleneck;
    the decoder mirrors them with transposed convolutions and skip
    concatenations. The final 1x1 conv uses tanh, so outputs lie in [-1, 1].
    """
    def double_conv(t, n_filters):
        # Two (Conv2D -> BatchNorm) stages at the same resolution.
        for _ in range(2):
            t = layers.Conv2D(n_filters, 3, padding="same", activation="relu")(t)
            t = layers.BatchNormalization()(t)
        return t

    inputs = layers.Input(input_shape)

    # Encoder: record each pre-pooling feature map as a skip connection.
    skips = []
    t = inputs
    for n_filters in (32, 64, 128, 256):
        skip = double_conv(t, n_filters)
        skips.append(skip)
        t = layers.MaxPooling2D(2)(skip)

    # Bottleneck at the coarsest resolution.
    t = double_conv(t, 512)

    # Decoder: upsample, fuse the matching skip, then refine.
    for skip, n_filters in zip(reversed(skips), (256, 128, 64, 32)):
        t = layers.Conv2DTranspose(n_filters, 2, strides=2, padding="same")(t)
        t = layers.Concatenate()([t, skip])
        t = double_conv(t, n_filters)

    outputs = layers.Conv2D(3, 1, activation="tanh")(t)
    return Model(inputs, outputs)
@app.on_event("startup")
async def load_model():
    """Build the U-Net and restore trained weights at startup, if present.

    Leaves the module-level `model` as None (demo mode) when the weights
    file does not exist.
    """
    global model
    if not os.path.exists(MODEL_PATH):
        print("⚠️ Model not found — demo mode")
        return
    model = build_unet()
    # One dummy forward pass creates the variables before load_weights().
    model(tf.zeros((1, 256, 256, 1)))
    model.load_weights(MODEL_PATH)
    print("✅ SAR model loaded successfully")
def preprocess(image_bytes):
    """Decode image bytes into a model-ready batch of shape (1, 256, 256, 1).

    Converts to grayscale, resizes to IMG_SIZE, and rescales pixel values
    from [0, 255] to [-1, 1] to match the model's tanh output range.
    """
    gray = Image.open(io.BytesIO(image_bytes)).convert("L")
    # LANCZOS preserves fine detail through the downscale.
    gray = gray.resize(IMG_SIZE, Image.LANCZOS)
    scaled = np.asarray(gray, dtype=np.float32) / 127.5 - 1.0
    # Add batch and channel axes: (H, W) -> (1, H, W, 1).
    return scaled[np.newaxis, :, :, np.newaxis]
def to_base64(arr_uint8):
    """Encode a uint8 image array as a base64 string of PNG bytes."""
    buffer = io.BytesIO()
    Image.fromarray(arr_uint8).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()
@app.get("/")
async def home(request: Request):
    """Serve the upload page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/colorize")
async def colorize(file: UploadFile = File(...)):
    """Colorize an uploaded SAR (grayscale) image.

    Returns JSON with the resized grayscale input and the colorized output,
    both base64-encoded PNGs. Falls back to a random demo image when no
    trained model is loaded.

    Raises:
        HTTPException: 400 for a non-image upload, a payload over 10MB,
            or bytes that cannot be decoded as an image.
    """
    # content_type can be None for malformed uploads; treat that as non-image
    # rather than letting .startswith raise an AttributeError (500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image.")
    contents = await file.read()
    if len(contents) > 10 * 1024 * 1024:
        raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")
    try:
        inp = preprocess(contents)
    except Exception:
        # Suppress the chained traceback; the 400 is the whole story.
        raise HTTPException(status_code=400, detail="Could not process image.") from None

    if model is None:
        # Demo mode: synthesize a greenish random image so the UI stays usable.
        # (Removed an unused `import random` — only np.random is used here.)
        dummy = np.random.randint(50, 200, (256, 256, 3), dtype=np.uint8)
        dummy[:, :, 0] = np.clip(dummy[:, :, 0], 30, 100)
        dummy[:, :, 1] = np.clip(dummy[:, :, 1], 80, 180)
        dummy[:, :, 2] = np.clip(dummy[:, :, 2], 30, 100)
        pred_b64 = to_base64(dummy)
    else:
        pred = model.predict(inp, verbose=0)[0]
        # Undo the [-1, 1] normalization back to displayable uint8 pixels.
        pred_uint8 = ((pred + 1) * 127.5).clip(0, 255).astype(np.uint8)
        pred_b64 = to_base64(pred_uint8)

    # Also return the (resized) input as base64 so the UI can show both sides.
    inp_disp = ((inp[0, :, :, 0] + 1) * 127.5).clip(0, 255).astype(np.uint8)
    inp_b64 = to_base64(inp_disp)
    return JSONResponse({
        "success": True,
        "input_b64": inp_b64,
        "output_b64": pred_b64,
    })
@app.get("/health")
async def health():
    """Liveness probe; also reports whether trained weights were loaded."""
    loaded = model is not None
    return {"status": "ok", "model_loaded": loaded}