import gradio as gr
import matplotlib.pyplot as plt
import tensorflow as tf
from huggingface_hub import snapshot_download
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import base64
import io
import numpy as np
from PIL import Image

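# Download the MSI-Net saliency model from the Hugging Face Hub and load the
# exported SavedModel as an inference-only Keras layer.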
model_path = snapshot_download(repo_id="alexanderkroner/MSI-Net")
loaded_model = tf.keras.layers.TFSMLayer(model_path, call_endpoint='serving_default')


def get_target_shape(original_shape):
    """Pick the model input shape (square, landscape, or portrait) whose
    aspect ratio is closest to that of the original image."""
    original_aspect_ratio = original_shape[0] / original_shape[1]
    square_mode = abs(original_aspect_ratio - 1.0)
    landscape_mode = abs(original_aspect_ratio - 240 / 320)
    portrait_mode = abs(original_aspect_ratio - 320 / 240)
    best_mode = min(square_mode, landscape_mode, portrait_mode)

    if best_mode == square_mode:
        return (320, 320)
    elif best_mode == landscape_mode:
        return (240, 320)
    else:
        return (320, 240)


def preprocess_input(input_image, target_shape):
    """Resize the image to fit the target shape, then pad it symmetrically
    so it matches the target shape exactly."""
    input_tensor = tf.expand_dims(input_image, axis=0)
    input_tensor = tf.image.resize(input_tensor, target_shape, preserve_aspect_ratio=True)

    vertical_padding = target_shape[0] - input_tensor.shape[1]
    horizontal_padding = target_shape[1] - input_tensor.shape[2]

    vertical_padding_1 = vertical_padding // 2
    vertical_padding_2 = vertical_padding - vertical_padding_1
    horizontal_padding_1 = horizontal_padding // 2
    horizontal_padding_2 = horizontal_padding - horizontal_padding_1

    input_tensor = tf.pad(
        input_tensor,
        [[0, 0], [vertical_padding_1, vertical_padding_2],
         [horizontal_padding_1, horizontal_padding_2], [0, 0]]
    )

    return input_tensor, [vertical_padding_1, vertical_padding_2], [horizontal_padding_1, horizontal_padding_2]


def postprocess_output(output_tensor, vertical_padding, horizontal_padding, original_shape):
    """Crop the padding, resize back to the original shape, and apply the inferno colormap."""
    output_tensor = output_tensor[
        :,
        vertical_padding[0]:output_tensor.shape[1] - vertical_padding[1],
        horizontal_padding[0]:output_tensor.shape[2] - horizontal_padding[1],
        :
    ]
    output_tensor = tf.image.resize(output_tensor, original_shape)
    output_array = output_tensor.numpy().squeeze()
    output_array = plt.cm.inferno(output_array)[..., :3]
    return output_array


def compute_saliency(input_image, alpha=0.65):
    """Predict a saliency map for the image and alpha-blend it with the input."""
    if input_image is None:
        return None

    original_shape = input_image.shape[:2]
    target_shape = get_target_shape(original_shape)

    input_tensor, vertical_padding, horizontal_padding = preprocess_input(input_image, target_shape)

    # TFSMLayer returns a dict of named outputs; pick the saliency tensor.
    saliency_map_dict = loaded_model(input_tensor)
    if "output" in saliency_map_dict:
        saliency_map = saliency_map_dict["output"]
    else:
        saliency_map = list(saliency_map_dict.values())[0]

    saliency_map = postprocess_output(saliency_map, vertical_padding, horizontal_padding, original_shape)
    blended_image = alpha * saliency_map + (1 - alpha) * input_image / 255

    return blended_image


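# REST API: accepts a base64-encoded image and returns the blended saliency
# map as a base64-encoded PNG.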
class SaliencyRequest(BaseModel):
    image_base64: str
    alpha: float = 0.65


app = FastAPI()


@app.get("/api/status")
async def api_status():
    return {"status": "ok", "message": "Saliency API running. POST to /api/predict"}


@app.post("/api/predict")
async def api_predict(request: SaliencyRequest):
    try:
        # Decode the base64-encoded input image
        image_data = base64.b64decode(request.image_base64)
        image = Image.open(io.BytesIO(image_data))

        image_array = np.array(image)

        # Convert grayscale images to RGB and drop the alpha channel from RGBA
        if len(image_array.shape) == 2:
            image_array = np.stack([image_array] * 3, axis=-1)
        elif image_array.shape[2] == 4:
            image_array = image_array[:, :, :3]

        result = compute_saliency(image_array, request.alpha)

        # Convert the blended float image in [0, 1] back to an 8-bit image
        result_image = (result * 255).astype(np.uint8)
        pil_image = Image.fromarray(result_image)

        # Return the result as a base64-encoded PNG
        buffered = io.BytesIO()
        pil_image.save(buffered, format="PNG")
        result_base64 = base64.b64encode(buffered.getvalue()).decode()

        return {"success": True, "saliency_map_base64": result_base64}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

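# A minimal client sketch for the endpoint above (not executed here). It
# assumes the server is running locally on the port used in __main__ below
# and that an image file named "input.jpg" exists:
#
#     import base64
#     import requests
#
#     with open("input.jpg", "rb") as f:
#         payload = {"image_base64": base64.b64encode(f.read()).decode()}
#     response = requests.post("http://localhost:7860/api/predict", json=payload)
#     with open("saliency.png", "wb") as f:
#         f.write(base64.b64decode(response.json()["saliency_map_base64"]))
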
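# Bundled example images for the Gradio demo.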
examples = [
    "examples/kirsten-frank-o1sXiz_LU1A-unsplash.jpg",
    "examples/oscar-fickel-F5ze5FkEu1g-unsplash.jpg",
    "examples/ting-tian-_79ZJS8pV70-unsplash.jpg",
    "examples/gina-domenique-LmrAUrHinqk-unsplash.jpg",
    "examples/robby-mccullough-r05GkQBcaPM-unsplash.jpg",
]

demo = gr.Interface(
    fn=compute_saliency,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Saliency Map"),
    examples=examples,
    title="Visual Saliency Prediction",
    description="A demo to predict where humans fixate on an image using a deep learning model trained on eye movement data. Upload an image file, take a snapshot from your webcam, or paste an image from the clipboard to compute the saliency map.",
    article="For more information on the model, check out [GitHub](https://github.com/alexanderkroner/saliency) and the corresponding [paper](https://doi.org/10.1016/j.neunet.2020.05.004).",
    allow_flagging="never",
    api_name="predict",
)

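# Mount the Gradio UI at the root path; the /api/* routes registered above
# remain available on the same server.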
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)