Spaces:
Build error
Build error
umyuu committed on
Commit ·
da8bdb9
1
Parent(s): 271d94c
refactoring
Browse files- App.pyの処理内容をSaliencyMapクラスに分割。
- 実行ログ出力用としてreporter.pyを追加。
- src/app.py +80 -46
- src/launch.py +7 -4
- src/reporter.py +49 -0
- src/saliency.py +75 -0
src/app.py
CHANGED
|
@@ -6,54 +6,73 @@ import argparse
|
|
| 6 |
from datetime import datetime
|
| 7 |
import sys
|
| 8 |
|
| 9 |
-
import cv2
|
| 10 |
import gradio as gr
|
| 11 |
import numpy as np
|
| 12 |
|
| 13 |
import utils
|
| 14 |
-
|
|
|
|
|
|
|
| 15 |
PROGRAM_NAME = 'SaliencyMapDemo'
|
| 16 |
__version__ = utils.get_package_version()
|
|
|
|
| 17 |
|
| 18 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
np.ndarray
|
| 29 |
-
カラーマップのJET画像
|
| 30 |
"""
|
| 31 |
-
#
|
| 32 |
-
|
| 33 |
-
|
|
|
|
| 34 |
success, saliencyMap = saliency.computeSaliency(image)
|
|
|
|
| 35 |
|
| 36 |
-
if success:
|
| 37 |
-
#
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
| 49 |
def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
|
| 50 |
"""
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
コマンドライン引数
|
| 55 |
-
|
| 56 |
-
起動したスタート時間
|
| 57 |
"""
|
| 58 |
# analytics_enabled=False
|
| 59 |
# https://github.com/gradio-app/gradio/issues/4226
|
|
@@ -68,22 +87,32 @@ def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
|
|
| 68 |
gr.Markdown(
|
| 69 |
"""
|
| 70 |
# Saliency Map demo.
|
| 71 |
-
1. inputタブで画像を選択します。
|
| 72 |
-
2. Submitボタンを押します。
|
| 73 |
-
※画像は外部送信していません。ローカルで処理が完結します。
|
| 74 |
-
3. 結果は、overlayタブに表示します。
|
| 75 |
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
submit_button = gr.Button("submit")
|
| 78 |
-
|
| 79 |
with gr.Row():
|
| 80 |
-
with gr.Tab("input"):
|
| 81 |
-
image_input = gr.Image()
|
| 82 |
-
with gr.Tab("overlay"):
|
| 83 |
-
image_overlay = gr.Image(interactive=False)
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
submit_button.click(
|
| 87 |
|
| 88 |
gr.Markdown(
|
| 89 |
f"""
|
|
@@ -93,7 +122,12 @@ def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
|
|
| 93 |
|
| 94 |
demo.queue(default_concurrency_limit=5)
|
| 95 |
|
| 96 |
-
|
| 97 |
|
| 98 |
# https://www.gradio.app/docs/gradio/blocks#blocks-launch
|
| 99 |
-
demo.launch(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from datetime import datetime
|
| 7 |
import sys
|
| 8 |
|
|
|
|
| 9 |
import gradio as gr
|
| 10 |
import numpy as np
|
| 11 |
|
| 12 |
import utils
|
| 13 |
+
from saliency import SaliencyMap, convertColorMap
|
| 14 |
+
from reporter import get_current_reporter
|
| 15 |
+
|
| 16 |
PROGRAM_NAME = 'SaliencyMapDemo'
|
| 17 |
__version__ = utils.get_package_version()
|
| 18 |
+
log = get_current_reporter()
|
| 19 |
|
| 20 |
+
def jetTab_Selected(image: np.ndarray):
    """
    Compute a saliency map for the input image and overlay it with the JET colormap.

    Parameters:
        image: input image (ndarray as provided by Gradio)

    Returns:
        np.ndarray: the JET overlay image, or the unmodified input image
        when the saliency computation fails.
    """
    saliency = SaliencyMap("SpectralResidual")
    success, saliencyMap = saliency.computeSaliency(image)
    if not success:
        # Mirror submit_Clicked: fall back to the original image on failure.
        return image
    return convertColorMap(image, saliencyMap, "jet")
|
| 28 |
+
|
| 29 |
+
def hotTab_Selected(image: np.ndarray):
    """
    Compute a saliency map for the input image and overlay it with the HOT colormap.

    Parameters:
        image: input image (ndarray as provided by Gradio)

    Returns:
        np.ndarray: the HOT overlay image, or the unmodified input image
        when the saliency computation fails.
    """
    saliency = SaliencyMap("SpectralResidual")
    success, saliencyMap = saliency.computeSaliency(image)
    if not success:
        # Mirror submit_Clicked: fall back to the original image on failure.
        return image
    return convertColorMap(image, saliencyMap, "hot")
|
| 37 |
+
|
| 38 |
+
def submit_Clicked(image: np.ndarray, algorithm: str):
    """
    Compute the saliency map of the input image and return two colormap overlays.

    Parameters:
        image: input image
        algorithm: saliency algorithm name ("SpectralResidual" or "FineGrained")

    Returns:
        np.ndarray: JET overlay image
        np.ndarray: HOT overlay image
        On failure, both return values are the unmodified input image.
    """
    log.info("#submit_Clicked")
    watch = utils.Stopwatch.startNew()

    saliency = SaliencyMap(algorithm)
    success, saliencyMap = saliency.computeSaliency(image)
    log.info("#SaliencyMap computeSaliency()")

    if not success:
        # Fall back to the input image when the computation fails.
        return image, image

    log.info("#jet")
    jet = convertColorMap(image, saliencyMap, "jet")
    log.info("#hot")
    hot = convertColorMap(image, saliencyMap, "hot")

    log.info(f"#submit_Clicked End{watch.stop():.3f}")
    return jet, hot
|
| 68 |
|
| 69 |
def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
|
| 70 |
"""
|
| 71 |
+
アプリの画面を作成し、Gradioサービスを起動します。
|
| 72 |
+
|
| 73 |
+
Parameters:
|
| 74 |
+
args: コマンドライン引数
|
| 75 |
+
watch: 起動したスタート時間
|
|
|
|
| 76 |
"""
|
| 77 |
# analytics_enabled=False
|
| 78 |
# https://github.com/gradio-app/gradio/issues/4226
|
|
|
|
| 87 |
gr.Markdown(
|
| 88 |
"""
|
| 89 |
# Saliency Map demo.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
""")
|
| 91 |
+
with gr.Accordion("取り扱い説明書", open=False):
|
| 92 |
+
gr.Markdown(
|
| 93 |
+
"""
|
| 94 |
+
1. inputタブで画像を選択します。
|
| 95 |
+
2. Submitボタンを押します。
|
| 96 |
+
※画像は外部送信していません。ローカルで処理が完結します。
|
| 97 |
+
3. 結果は、JETタブとHOTタブに表示します。
|
| 98 |
+
""")
|
| 99 |
+
|
| 100 |
+
algorithmType = gr.Radio(["SpectralResidual", "FineGrained"], label="Saliency", value="SpectralResidual", interactive=True)
|
| 101 |
|
| 102 |
submit_button = gr.Button("submit")
|
| 103 |
+
|
| 104 |
with gr.Row():
|
| 105 |
+
with gr.Tab("input", id="input"):
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
+
image_input = gr.Image(sources = ["upload", "clipboard"], interactive=True)
|
| 108 |
+
with gr.Tab("overlay(JET)"):
|
| 109 |
+
image_overlay_jet = gr.Image(interactive=False)
|
| 110 |
+
#tab_jet.select(jetTab_Selected, inputs=[image_input], outputs=image_overlay_jet)
|
| 111 |
+
with gr.Tab("overlay(HOT)"):
|
| 112 |
+
image_overlay_hot = gr.Image(interactive=False)
|
| 113 |
+
#tab_hot.select(hotTab_Selected, inputs=[image_input], outputs=image_overlay_hot, api_name=False)
|
| 114 |
|
| 115 |
+
submit_button.click(submit_Clicked, inputs=[image_input, algorithmType], outputs=[image_overlay_jet, image_overlay_hot])
|
| 116 |
|
| 117 |
gr.Markdown(
|
| 118 |
f"""
|
|
|
|
| 122 |
|
| 123 |
demo.queue(default_concurrency_limit=5)
|
| 124 |
|
| 125 |
+
log.info(f"#アプリ起動完了({watch.stop():.3f}s)")
|
| 126 |
|
| 127 |
# https://www.gradio.app/docs/gradio/blocks#blocks-launch
|
| 128 |
+
demo.launch(
|
| 129 |
+
max_file_size=args.max_file_size,
|
| 130 |
+
server_port=args.server_port,
|
| 131 |
+
inbrowser=True,
|
| 132 |
+
share=False,
|
| 133 |
+
)
|
src/launch.py
CHANGED
|
@@ -6,15 +6,18 @@ import argparse
|
|
| 6 |
from datetime import datetime
|
| 7 |
|
| 8 |
from utils import get_package_version, Stopwatch
|
|
|
|
| 9 |
|
| 10 |
def main():
|
| 11 |
"""
|
| 12 |
エントリーポイント
|
| 13 |
-
コマンドライン引数の解析を行います
|
|
|
|
| 14 |
"""
|
| 15 |
-
|
|
|
|
| 16 |
watch = Stopwatch.startNew()
|
| 17 |
-
|
| 18 |
import app
|
| 19 |
|
| 20 |
parser = argparse.ArgumentParser(prog=app.PROGRAM_NAME, description="SaliencyMapDemo")
|
|
@@ -25,4 +28,4 @@ def main():
|
|
| 25 |
app.run(parser.parse_args(), watch)
|
| 26 |
|
| 27 |
if __name__ == "__main__":
|
| 28 |
-
|
|
|
|
| 6 |
from datetime import datetime
|
| 7 |
|
| 8 |
from utils import get_package_version, Stopwatch
|
| 9 |
+
from reporter import get_current_reporter
|
| 10 |
|
| 11 |
def main():
|
| 12 |
"""
|
| 13 |
エントリーポイント
|
| 14 |
+
1, コマンドライン引数の解析を行います
|
| 15 |
+
2, アプリを起動します。
|
| 16 |
"""
|
| 17 |
+
log = get_current_reporter()
|
| 18 |
+
log.info("#アプリ起動中")
|
| 19 |
watch = Stopwatch.startNew()
|
| 20 |
+
|
| 21 |
import app
|
| 22 |
|
| 23 |
parser = argparse.ArgumentParser(prog=app.PROGRAM_NAME, description="SaliencyMapDemo")
|
|
|
|
| 28 |
app.run(parser.parse_args(), watch)
|
| 29 |
|
| 30 |
if __name__ == "__main__":
|
| 31 |
+
main()
|
src/reporter.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
"""
Reporter

Creates the application logger in a single place so that log handlers are
not registered more than once.

Example:
    from reporter import get_current_reporter

    logger = get_current_reporter()
    logger.info("message")
"""
from logging import Logger, getLogger, Formatter, StreamHandler
from logging import DEBUG

# Stack of configured loggers; the most recently created one is "current".
_reporters = []

def get_current_reporter() -> Logger:
    """Return the most recently created logger."""
    return _reporters[-1]

def __make_reporter(name: str = 'SaliencyMapDemo') -> None:
    """
    Create a logger with a console handler and register it in _reporters.

    @see https://docs.python.org/3/howto/logging-cookbook.html

    Parameters:
        name: application name (used as the logger name)
    """
    handler = StreamHandler()  # log to the console
    # NOTE(review): no separator between %(asctime)s and %(message)s; the
    # app's messages all start with "#", which appears to act as the
    # separator — confirm before changing.
    formatter = Formatter('%(asctime)s%(message)s')
    handler.setFormatter(formatter)
    handler.setLevel(DEBUG)

    logger = getLogger(name)
    logger.setLevel(DEBUG)
    logger.addHandler(handler)
    _reporters.append(logger)

# Configure the logger once at import time so handlers are never duplicated.
__make_reporter()

def main():
    """
    Entry point for a quick self-check when run as a script.
    """
    assert len(_reporters) == 1

    logger = get_current_reporter()
    logger.debug("main")

if __name__ == "__main__":
    main()
|
src/saliency.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SaliencyMap:
    """
    Computes a static saliency map using OpenCV's saliency module.

    Example:
        from saliency import SaliencyMap

        saliency = SaliencyMap("SpectralResidual")
        success, saliencyMap = saliency.computeSaliency(image)
    """
    def __init__(
        self,
        algorithm: Literal["SpectralResidual", "FineGrained"] = "SpectralResidual",
    ):
        """
        Parameters:
            algorithm: saliency algorithm. "SpectralResidual" selects
                StaticSaliencySpectralResidual; any other value selects
                StaticSaliencyFineGrained.
        """
        self.algorithm = algorithm
        # Create the OpenCV saliency detector.
        if algorithm == "SpectralResidual":
            self.saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
        else:
            self.saliency = cv2.saliency.StaticSaliencyFineGrained_create()

    def computeSaliency(self, image: np.ndarray):
        """
        Compute the saliency map of the input image.

        Parameters:
            image: input image

        Returns:
            bool: True when the saliency map was computed, False otherwise
            np.ndarray: the saliency map
        """
        return self.saliency.computeSaliency(image)
|
| 43 |
+
|
| 44 |
+
def convertColorMap(
        image: np.ndarray,
        saliencyMap: np.ndarray,
        colormap_name: Literal["jet", "hot"] = "jet"):
    """
    Convert a saliency map to a colormap and blend it over the input image.

    Parameters:
        image: input image
        saliencyMap: saliency map; scaled by 255 and cast to uint8 before
            the colormap is applied (assumes values in [0, 1] — TODO confirm)
        colormap_name: colormap to apply; "jet" selects COLORMAP_JET,
            any other value selects COLORMAP_HOT

    Returns:
        np.ndarray: the blended image in RGBA format
    """
    # Scale the saliency map to 8-bit and apply the chosen colormap.
    saliencyMap = (saliencyMap * 255).astype("uint8")
    colormap = cv2.COLORMAP_JET if colormap_name == "jet" else cv2.COLORMAP_HOT
    saliencyMap = cv2.applyColorMap(saliencyMap, colormap)

    # Blend the input image and the colormap 50/50.
    overlay = cv2.addWeighted(image, 0.5, saliencyMap, 0.5, 0)

    return cv2.cvtColor(overlay, cv2.COLOR_BGR2RGBA)
|