Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.requests import Request | |
| from fastapi.responses import JSONResponse | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import os | |
| import base64 | |
| import tensorflow as tf | |
| from tensorflow.keras import layers, Model | |
# --- Inference configuration ------------------------------------------------
# Size every uploaded image is resized to before it is fed to the network.
IMG_SIZE = (256, 256)
# Weights file looked up by load_model(); if it is missing the service
# falls back to demo mode.
MODEL_PATH = "sar_model.weights.h5"
# Set by load_model(); None means no trained network is available.
model = None

# --- Web application --------------------------------------------------------
app = FastAPI(title="SAR Image Colorization")
# Files in ./static are served under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# HTML templates are read from ./templates.
templates = Jinja2Templates(directory="templates")
def build_unet(input_shape=(256, 256, 1)):
    """Assemble the U-Net that maps an input image to a 3-channel output.

    The network has four encoder stages (32, 64, 128 and 256 filters), a
    512-filter bottleneck and four mirrored decoder stages, each joined to
    its encoder counterpart through a skip connection.  The final 1x1
    convolution uses a ``tanh`` activation, so outputs lie in [-1, 1].

    Layers are created in a fixed order (encoder, bottleneck, decoder,
    output); keep that order stable so previously saved weights keep
    lining up with the layers.

    Args:
        input_shape: Shape of one input sample, without the batch axis.

    Returns:
        An uncompiled Keras ``Model`` with freshly initialised weights.
    """

    def double_conv(tensor, n_filters):
        # Two rounds of 3x3 convolution (ReLU), each followed by batch norm.
        for _ in range(2):
            tensor = layers.Conv2D(
                n_filters, 3, padding="same", activation="relu"
            )(tensor)
            tensor = layers.BatchNormalization()(tensor)
        return tensor

    def downsample(tensor, n_filters):
        # Returns (features kept for the skip connection, 2x-pooled features).
        features = double_conv(tensor, n_filters)
        return features, layers.MaxPooling2D(2)(features)

    def upsample(tensor, skip_features, n_filters):
        # 2x transposed-conv upsampling, merge with the skip, then refine.
        tensor = layers.Conv2DTranspose(
            n_filters, 2, strides=2, padding="same"
        )(tensor)
        merged = layers.Concatenate()([tensor, skip_features])
        return double_conv(merged, n_filters)

    encoder_widths = (32, 64, 128, 256)

    inputs = layers.Input(input_shape)

    # Contracting path: remember each stage's pre-pool output for later.
    skips = []
    x = inputs
    for width in encoder_widths:
        skip, x = downsample(x, width)
        skips.append(skip)

    # Bottleneck.
    x = double_conv(x, 512)

    # Expanding path: deepest skip connection first.
    for width, skip in zip(reversed(encoder_widths), reversed(skips)):
        x = upsample(x, skip, width)

    outputs = layers.Conv2D(3, 1, activation="tanh")(x)
    return Model(inputs, outputs)
async def load_model():
    """Load the colorization network into the module-level ``model``.

    When ``MODEL_PATH`` exists, a U-Net is built, run once on a dummy batch
    and populated from the weights file.  When it does not exist, ``model``
    is left untouched and a demo-mode notice is printed.

    The global is assigned only after the weights have loaded, so a failed
    load never leaves an untrained network published as ``model``.

    NOTE(review): nothing in this view registers this coroutine (for
    example as a startup hook) — confirm that it is actually invoked.

    Raises:
        Any exception raised by ``build_unet`` or ``load_weights``
        propagates unchanged.
    """
    global model

    if not os.path.exists(MODEL_PATH):
        print("⚠️ Model not found — demo mode")
        return

    network = build_unet()
    # One forward pass on a dummy batch before the weights are loaded.
    network(tf.zeros((1, 256, 256, 1)))
    network.load_weights(MODEL_PATH)

    # Publish only once loading has succeeded.
    model = network
    print("✅ SAR model loaded successfully")
def preprocess(image_bytes):
    """Turn encoded image bytes into a network-ready batch.

    The image is decoded, converted to 8-bit grayscale (mode ``"L"``),
    resized to ``IMG_SIZE`` with Lanczos resampling and rescaled from
    [0, 255] to [-1, 1].

    Args:
        image_bytes: Contents of an encoded image file.

    Returns:
        A float32 array of shape ``(1, height, width, 1)``.
    """
    stream = io.BytesIO(image_bytes)
    gray = Image.open(stream).convert("L")
    # Lanczos resampling for a high-quality resize.
    resized = gray.resize(IMG_SIZE, Image.LANCZOS)
    pixels = np.array(resized, dtype=np.float32)
    scaled = pixels / 127.5 - 1.0
    # Add a leading batch axis and a trailing channel axis.
    return scaled[np.newaxis, ..., np.newaxis]
def to_base64(arr_uint8):
    """Encode an image array as a base64 PNG string.

    Args:
        arr_uint8: Array accepted by ``Image.fromarray``; callers in this
            file pass uint8 data.

    Returns:
        The PNG file bytes, base64-encoded and decoded to ``str``.
    """
    png_buffer = io.BytesIO()
    Image.fromarray(arr_uint8).save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()
async def home(request: Request):
    """Render the ``index.html`` template for the incoming request.

    NOTE(review): no route decorator is visible in this view — confirm how
    this handler is registered.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def _to_uint8_image(scaled):
    """Map an array scaled to [-1, 1] back to uint8 pixel values in [0, 255]."""
    return ((scaled + 1) * 127.5).clip(0, 255).astype(np.uint8)


async def colorize(file: UploadFile = File(...)):
    """Colorize an uploaded image and return input and output as base64 PNGs.

    The upload is validated (declared image content type, at most 10 MB),
    preprocessed and run through the loaded model.  When no model is loaded
    (``model is None``) a random placeholder image is returned instead.

    NOTE(review): no route decorator is visible in this view — confirm how
    this handler is registered.

    Args:
        file: The uploaded file.

    Returns:
        A ``JSONResponse`` with ``success``, ``input_b64`` (the preprocessed
        grayscale input) and ``output_b64`` (the prediction or placeholder).

    Raises:
        HTTPException: With status 400 when the upload is not declared as an
            image, exceeds 10 MB, or cannot be decoded.
    """
    max_upload_bytes = 10 * 1024 * 1024

    # content_type is None when the client sent no Content-Type for the part;
    # treat that as "not an image" rather than failing on None.startswith.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image.")

    # Read one byte past the limit: enough to detect an oversized upload
    # without pulling the whole thing into memory first.
    contents = await file.read(max_upload_bytes + 1)
    if len(contents) > max_upload_bytes:
        raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")

    try:
        inp = preprocess(contents)
    except Exception as exc:
        # Any decode/resize failure is reported as a bad upload.
        raise HTTPException(
            status_code=400, detail="Could not process image."
        ) from exc

    if model is None:
        # Demo mode: random noise with the green channel kept in a brighter
        # range than red and blue.
        dummy = np.random.randint(50, 200, (256, 256, 3), dtype=np.uint8)
        dummy[:, :, 0] = np.clip(dummy[:, :, 0], 30, 100)
        dummy[:, :, 1] = np.clip(dummy[:, :, 1], 80, 180)
        dummy[:, :, 2] = np.clip(dummy[:, :, 2], 30, 100)
        pred_b64 = to_base64(dummy)
    else:
        pred = model.predict(inp, verbose=0)[0]
        pred_b64 = to_base64(_to_uint8_image(pred))

    # Also return the preprocessed input so the client can show both images.
    inp_b64 = to_base64(_to_uint8_image(inp[0, :, :, 0]))

    return JSONResponse({
        "success": True,
        "input_b64": inp_b64,
        "output_b64": pred_b64,
    })
async def health():
    """Report service status and whether a model is currently loaded.

    NOTE(review): no route decorator is visible in this view — confirm how
    this handler is registered.
    """
    model_loaded = model is not None
    return {"status": "ok", "model_loaded": model_loaded}