egoisv committed on
Commit
59307c8
·
verified ·
1 Parent(s): f804c41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -48
app.py CHANGED
@@ -1,19 +1,17 @@
1
- from fastapi import FastAPI, HTTPException
2
- from fastapi.responses import JSONResponse
3
- from pydantic import BaseModel
4
  import tensorflow as tf
5
  from huggingface_hub import snapshot_download
6
- import matplotlib.pyplot as plt
 
7
  import base64
8
  import io
9
  import numpy as np
10
  from PIL import Image
11
 
12
  # Download and load model
13
- print("Loading model...")
14
  model_path = snapshot_download(repo_id="alexanderkroner/MSI-Net")
15
  loaded_model = tf.keras.layers.TFSMLayer(model_path, call_endpoint='serving_default')
16
- print("Model loaded!")
17
 
18
  def get_target_shape(original_shape):
19
  original_aspect_ratio = original_shape[0] / original_shape[1]
@@ -61,21 +59,41 @@ def postprocess_output(output_tensor, vertical_padding, horizontal_padding, orig
61
  output_array = plt.cm.inferno(output_array)[..., :3]
62
  return output_array
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  class SaliencyRequest(BaseModel):
65
  image_base64: str
66
  alpha: float = 0.65
67
 
68
- app = FastAPI(title="Saliency Map API")
69
 
70
- @app.get("/")
71
- async def root():
72
- return {"status": "ok", "message": "Saliency Map API is running. POST to /predict with image_base64"}
73
 
74
- @app.post("/predict")
75
- async def generate_saliency(request: SaliencyRequest):
76
  try:
77
- print(f"Received request, image size: {len(request.image_base64)} chars")
78
-
79
  # Decode base64 image
80
  image_data = base64.b64decode(request.image_base64)
81
  image = Image.open(io.BytesIO(image_data))
@@ -89,32 +107,11 @@ async def generate_saliency(request: SaliencyRequest):
89
  elif image_array.shape[2] == 4:
90
  image_array = image_array[:, :, :3]
91
 
92
- print(f"Image shape: {image_array.shape}")
93
-
94
- # Get target shape
95
- original_shape = image_array.shape[:2]
96
- target_shape = get_target_shape(original_shape)
97
-
98
- # Preprocess
99
- input_tensor, vertical_padding, horizontal_padding = preprocess_input(image_array, target_shape)
100
-
101
- # Run model
102
- print("Running inference...")
103
- saliency_map_dict = loaded_model(input_tensor)
104
 
105
- if "output" in saliency_map_dict:
106
- saliency_map = saliency_map_dict["output"]
107
- else:
108
- saliency_map = list(saliency_map_dict.values())[0]
109
-
110
- # Postprocess
111
- saliency_map = postprocess_output(saliency_map, vertical_padding, horizontal_padding, original_shape)
112
-
113
- # Blend
114
- blended_image = request.alpha * saliency_map + (1 - request.alpha) * image_array / 255
115
-
116
- # Convert to image
117
- result_image = (blended_image * 255).astype(np.uint8)
118
  pil_image = Image.fromarray(result_image)
119
 
120
  # Convert to base64
@@ -122,19 +119,38 @@ async def generate_saliency(request: SaliencyRequest):
122
  pil_image.save(buffered, format="PNG")
123
  result_base64 = base64.b64encode(buffered.getvalue()).decode()
124
 
125
- print(f"Success! Result size: {len(result_base64)} chars")
126
-
127
- return JSONResponse({
128
- "success": True,
129
- "saliency_map_base64": result_base64
130
- })
131
 
132
  except Exception as e:
133
- print(f"Error: {str(e)}")
134
- import traceback
135
- traceback.print_exc()
136
  raise HTTPException(status_code=500, detail=str(e))
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  if __name__ == "__main__":
139
  import uvicorn
140
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
 
3
  import tensorflow as tf
4
  from huggingface_hub import snapshot_download
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel
7
  import base64
8
  import io
9
  import numpy as np
10
  from PIL import Image
11
 
12
# Download and load model
# Fetch the MSI-Net SavedModel snapshot from the Hugging Face Hub
# (cached locally by huggingface_hub after the first download).
model_path = snapshot_download(repo_id="alexanderkroner/MSI-Net")
# Wrap the SavedModel directory as an inference-only Keras layer bound to
# its 'serving_default' signature (Keras 3 way of loading a SavedModel).
loaded_model = tf.keras.layers.TFSMLayer(model_path, call_endpoint='serving_default')
 
15
 
16
  def get_target_shape(original_shape):
17
  original_aspect_ratio = original_shape[0] / original_shape[1]
 
59
  output_array = plt.cm.inferno(output_array)[..., :3]
60
  return output_array
61
 
62
def compute_saliency(input_image, alpha=0.65):
    """Predict a saliency overlay for *input_image* with MSI-Net.

    The image is resized/padded to a model-friendly shape, run through
    the loaded model, and the resulting heatmap is alpha-blended over
    the (normalized) original image.

    Args:
        input_image: RGB image array (assumed H x W x 3 uint8, as
            delivered by ``gr.Image`` — TODO confirm), or ``None``.
        alpha: Blend weight of the saliency heatmap over the image.

    Returns:
        A float array with the blended overlay, or ``None`` when no
        image was supplied.
    """
    if input_image is None:
        return None

    source_shape = input_image.shape[:2]
    model_shape = get_target_shape(source_shape)

    input_tensor, pad_vertical, pad_horizontal = preprocess_input(
        input_image, model_shape
    )

    # The model returns a dict of named output tensors; prefer the key
    # "output" and otherwise fall back to the first available entry.
    outputs = loaded_model(input_tensor)
    if "output" in outputs:
        raw_saliency = outputs["output"]
    else:
        raw_saliency = next(iter(outputs.values()))

    heatmap = postprocess_output(
        raw_saliency, pad_vertical, pad_horizontal, source_shape
    )

    # Alpha-blend the colormapped heatmap with the [0, 1]-scaled image.
    return alpha * heatmap + (1 - alpha) * input_image / 255
79
+
80
+ # =============================================================================
81
+ # FastAPI endpoint for direct API access
82
+ # =============================================================================
83
+
84
class SaliencyRequest(BaseModel):
    # Request payload for the saliency prediction endpoint.
    image_base64: str  # base64-encoded bytes of the input image file
    alpha: float = 0.65  # blend weight of the saliency map over the image
87
 
88
# FastAPI application hosting the JSON API (the Gradio UI is mounted
# onto this same app further down the file).
app = FastAPI()
89
 
90
@app.get("/api/status")
async def api_status():
    """Health-check endpoint reporting that the API is reachable."""
    payload = {
        "status": "ok",
        "message": "Saliency API running. POST to /api/predict",
    }
    return payload
93
 
94
+ @app.post("/api/predict")
95
+ async def api_predict(request: SaliencyRequest):
96
  try:
 
 
97
  # Decode base64 image
98
  image_data = base64.b64decode(request.image_base64)
99
  image = Image.open(io.BytesIO(image_data))
 
107
  elif image_array.shape[2] == 4:
108
  image_array = image_array[:, :, :3]
109
 
110
+ # Generate saliency map
111
+ result = compute_saliency(image_array, request.alpha)
 
 
 
 
 
 
 
 
 
 
112
 
113
+ # Convert result back to image
114
+ result_image = (result * 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
115
  pil_image = Image.fromarray(result_image)
116
 
117
  # Convert to base64
 
119
  pil_image.save(buffered, format="PNG")
120
  result_base64 = base64.b64encode(buffered.getvalue()).decode()
121
 
122
+ return {"success": True, "saliency_map_base64": result_base64}
 
 
 
 
 
123
 
124
  except Exception as e:
 
 
 
125
  raise HTTPException(status_code=500, detail=str(e))
126
 
127
# ---------------------------------------------------------------------------
# Gradio interface (for UI)
# ---------------------------------------------------------------------------

# Sample images offered below the input widget.
example_images = [
    "examples/kirsten-frank-o1sXiz_LU1A-unsplash.jpg",
    "examples/oscar-fickel-F5ze5FkEu1g-unsplash.jpg",
    "examples/ting-tian-_79ZJS8pV70-unsplash.jpg",
    "examples/gina-domenique-LmrAUrHinqk-unsplash.jpg",
    "examples/robby-mccullough-r05GkQBcaPM-unsplash.jpg",
]

demo = gr.Interface(
    fn=compute_saliency,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Saliency Map"),
    examples=example_images,
    title="Visual Saliency Prediction",
    description="A demo to predict where humans fixate on an image using a deep learning model trained on eye movement data. Upload an image file, take a snapshot from your webcam, or paste an image from the clipboard to compute the saliency map.",
    article="For more information on the model, check out [GitHub](https://github.com/alexanderkroner/saliency) and the corresponding [paper](https://doi.org/10.1016/j.neunet.2020.05.004).",
    allow_flagging="never",
    api_name="predict",
)

# Serve the Gradio UI from the FastAPI app's root; the /api/* routes
# registered above remain available alongside it.
app = gr.mount_gradio_app(app, demo, path="/")
153
+
154
# Entry point: serve the combined FastAPI + Gradio application with uvicorn
# (port 7860 — presumably the Hugging Face Spaces convention; confirm).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)