"""Visual saliency demo: a Gradio UI plus a small JSON API served by FastAPI.

The MSI-Net model is downloaded from the Hugging Face Hub when this module is
imported. The Gradio interface is mounted at ``/`` and the JSON endpoints live
under ``/api``.
"""

import base64
import io
import logging

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, HTTPException
from huggingface_hub import snapshot_download
from PIL import Image
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Download and load model
model_path = snapshot_download(repo_id="alexanderkroner/MSI-Net")
loaded_model = tf.keras.layers.TFSMLayer(model_path, call_endpoint='serving_default')


def get_target_shape(original_shape):
    """Return the model input shape whose aspect ratio best matches the image.

    Args:
        original_shape: Two-element sequence; the ratio ``[0] / [1]`` is
            compared against 1, 240/320 and 320/240. The caller in this file
            passes ``image.shape[:2]``.

    Returns:
        ``(320, 320)``, ``(240, 320)`` or ``(320, 240)``, whichever ratio is
        closest. Ties resolve in that order.
    """
    original_aspect_ratio = original_shape[0] / original_shape[1]

    square_mode = abs(original_aspect_ratio - 1.0)
    landscape_mode = abs(original_aspect_ratio - 240 / 320)
    portrait_mode = abs(original_aspect_ratio - 320 / 240)

    best_mode = min(square_mode, landscape_mode, portrait_mode)

    if best_mode == square_mode:
        return (320, 320)
    elif best_mode == landscape_mode:
        return (240, 320)
    else:
        return (320, 240)


def preprocess_input(input_image, target_shape):
    """Resize an image to fit ``target_shape`` and pad it to exactly that size.

    A leading batch axis is added, the image is resized with its aspect ratio
    preserved, and the remaining space is filled by padding split as evenly as
    possible between the two sides of each spatial axis.

    Args:
        input_image: Image array without a batch axis.
        target_shape: ``(height, width)`` the padded tensor must have.

    Returns:
        A tuple ``(input_tensor, vertical_padding, horizontal_padding)`` where
        each padding is a two-element list ``[before, after]`` that
        ``postprocess_output`` uses to crop the prediction again.
    """
    input_tensor = tf.expand_dims(input_image, axis=0)

    input_tensor = tf.image.resize(
        input_tensor, target_shape, preserve_aspect_ratio=True
    )

    vertical_padding = target_shape[0] - input_tensor.shape[1]
    horizontal_padding = target_shape[1] - input_tensor.shape[2]

    # An odd amount of padding puts the extra pixel on the second side.
    vertical_padding_1 = vertical_padding // 2
    vertical_padding_2 = vertical_padding - vertical_padding_1

    horizontal_padding_1 = horizontal_padding // 2
    horizontal_padding_2 = horizontal_padding - horizontal_padding_1

    input_tensor = tf.pad(
        input_tensor,
        [
            [0, 0],
            [vertical_padding_1, vertical_padding_2],
            [horizontal_padding_1, horizontal_padding_2],
            [0, 0],
        ],
    )

    return (
        input_tensor,
        [vertical_padding_1, vertical_padding_2],
        [horizontal_padding_1, horizontal_padding_2],
    )


def postprocess_output(output_tensor, vertical_padding, horizontal_padding, original_shape):
    """Turn the raw model output into a colour-mapped image at original size.

    Args:
        output_tensor: Batched model output; assumed to be a single-channel
            map so that ``squeeze()`` leaves a 2-D array (TODO confirm against
            the model signature).
        vertical_padding: ``[before, after]`` rows added by ``preprocess_input``.
        horizontal_padding: ``[before, after]`` columns added by
            ``preprocess_input``.
        original_shape: ``(height, width)`` to resize the cropped map to.

    Returns:
        The map rendered with matplotlib's ``inferno`` colormap, with the
        colormap's alpha channel dropped.
    """
    # Remove the padding that was added before inference.
    output_tensor = output_tensor[
        :,
        vertical_padding[0]:output_tensor.shape[1] - vertical_padding[1],
        horizontal_padding[0]:output_tensor.shape[2] - horizontal_padding[1],
        :,
    ]

    output_tensor = tf.image.resize(output_tensor, original_shape)

    output_array = output_tensor.numpy().squeeze()
    output_array = plt.cm.inferno(output_array)[..., :3]

    return output_array


def compute_saliency(input_image, alpha=0.65):
    """Predict a saliency map for an image and blend it over the image.

    Args:
        input_image: Image array whose first two axes are the spatial size and
            whose values are on a 0-255 scale (they are divided by 255 for the
            blend), or ``None``.
        alpha: Weight of the colour-mapped saliency map in the blend; the
            input image gets ``1 - alpha``.

    Returns:
        The blended image as a float array, or ``None`` when ``input_image``
        is ``None``.
    """
    if input_image is None:
        return None

    original_shape = input_image.shape[:2]
    target_shape = get_target_shape(original_shape)

    input_tensor, vertical_padding, horizontal_padding = preprocess_input(
        input_image, target_shape
    )

    saliency_map_dict = loaded_model(input_tensor)

    # Prefer the "output" key; otherwise fall back to the first returned tensor.
    if "output" in saliency_map_dict:
        saliency_map = saliency_map_dict["output"]
    else:
        saliency_map = list(saliency_map_dict.values())[0]

    saliency_map = postprocess_output(
        saliency_map, vertical_padding, horizontal_padding, original_shape
    )

    blended_image = alpha * saliency_map + (1 - alpha) * input_image / 255

    return blended_image


# =============================================================================
# FastAPI endpoint for direct API access
# =============================================================================

class SaliencyRequest(BaseModel):
    """Request body for ``POST /api/predict``."""

    # Base64-encoded bytes of an image file in any format Pillow can open.
    image_base64: str
    # Blend weight of the saliency map; forwarded to ``compute_saliency``.
    alpha: float = 0.65


app = FastAPI()


@app.get("/api/status")
async def api_status():
    """Report that the API is up."""
    return {"status": "ok", "message": "Saliency API running. POST to /api/predict"}


def _decode_rgb_image(image_base64):
    """Decode a base64-encoded image file into a 3-channel RGB array."""
    image_data = base64.b64decode(image_base64)
    with Image.open(io.BytesIO(image_data)) as image:
        # convert("RGB") handles every Pillow mode (greyscale, palette, LA,
        # RGBA, CMYK, ...) instead of guessing from the array's channel count.
        return np.array(image.convert("RGB"))


# Declared with plain ``def`` so FastAPI runs it in its threadpool: the
# TensorFlow call is blocking and would otherwise stall the event loop that
# also serves the mounted Gradio UI.
@app.post("/api/predict")
def api_predict(request: SaliencyRequest):
    """Compute a blended saliency image for a base64-encoded input image.

    Args:
        request: The image payload and blend weight.

    Returns:
        ``{"success": True, "saliency_map_base64": <base64 PNG>}``.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or not a
            readable image; 500 when inference or encoding fails.
    """
    try:
        image_array = _decode_rgb_image(request.image_base64)
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        # binascii.Error is a ValueError; Pillow's UnidentifiedImageError is
        # an OSError. Both mean the client sent something unusable.
        raise HTTPException(
            status_code=400, detail=f"Invalid image payload: {e}"
        ) from e

    try:
        # Generate saliency map
        result = compute_saliency(image_array, request.alpha)

        # Clip before the uint8 cast so an alpha outside [0, 1] saturates
        # instead of wrapping around.
        result_image = (np.clip(result, 0.0, 1.0) * 255).astype(np.uint8)
        pil_image = Image.fromarray(result_image)

        # Convert to base64
        buffered = io.BytesIO()
        pil_image.save(buffered, format="PNG")
        result_base64 = base64.b64encode(buffered.getvalue()).decode()
    except Exception as e:
        logger.exception("Saliency prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"success": True, "saliency_map_base64": result_base64}


# =============================================================================
# Gradio interface (for UI)
# =============================================================================

examples = [
    "examples/kirsten-frank-o1sXiz_LU1A-unsplash.jpg",
    "examples/oscar-fickel-F5ze5FkEu1g-unsplash.jpg",
    "examples/ting-tian-_79ZJS8pV70-unsplash.jpg",
    "examples/gina-domenique-LmrAUrHinqk-unsplash.jpg",
    "examples/robby-mccullough-r05GkQBcaPM-unsplash.jpg",
]

demo = gr.Interface(
    fn=compute_saliency,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Saliency Map"),
    examples=examples,
    title="Visual Saliency Prediction",
    description="A demo to predict where humans fixate on an image using a deep learning model trained on eye movement data. Upload an image file, take a snapshot from your webcam, or paste an image from the clipboard to compute the saliency map.",
    article="For more information on the model, check out [GitHub](https://github.com/alexanderkroner/saliency) and the corresponding [paper](https://doi.org/10.1016/j.neunet.2020.05.004).",
    allow_flagging="never",
    api_name="predict"
)

# Mount FastAPI to Gradio. The /api routes above are registered first, so
# they are matched before the Gradio app mounted at "/".
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)