"""Gradio app used only for the Hugging Face demo. The main code can be run as described in the README."""
|
|
| from dataclasses import dataclass, field |
| from typing import Any, Dict, Optional |
| from pathlib import Path |
|
|
| import shutil |
| import tempfile |
| import gradio as gr |
|
|
| from src.panaroma_stitcher.kornia import KorniaStitcher |
| from src.panaroma_stitcher.opencv_simple import SimpleStitcher |
| from src.panaroma_stitcher.keypoint_stitcher import KeypointStitcher |
| from src.panaroma_stitcher.detailed_stitcher import DetailedStitcher |
| from src.panaroma_stitcher.sequential_stitcher import SequentialStitcher |
|
|
|
|
@dataclass
class StitcherDemo:
    """Gradio front-end for the Hugging Face demo of the panorama stitchers.

    The page lets the user pick one of five stitcher back-ends, tune that
    back-end's parameters and upload the images to stitch.

    Attributes:
        parameters: Gradio input components, keyed by parameter name.
        param_values: Current value of every parameter, keyed like
            ``parameters``. Seeded with defaults in ``__post_init__`` and kept
            in sync with the UI by ``dummy_logger``.
    """

    # NOTE(review): the settings live on one shared instance, so every browser
    # session of the running demo reads and writes the same values.
    parameters: Dict[str, Any] = field(init=False, default_factory=dict)
    param_values: Dict[str, Any] = field(init=False, default_factory=dict)

    def __post_init__(self) -> None:
        """Seed ``param_values`` with the default value of every parameter."""
        self.param_values.update(
            {
                # Which back-end ``callback`` dispatches to.
                "model": "Simple Stitcher",
                # Simple stitcher.
                "stitcher_type": "panorama",
                # Detailed stitcher.
                "detect_method": "sift",
                "match_type": "homography",
                "num_feat": 500,
                "conf_thr": 0.05,
                "cam_est": "homography",
                "cam_adj": "ray",
                # Keypoint / sequential stitchers.
                "matching_method": "bf",
                "detector_method": "sift",
                "number_feature": 500,
                # Kornia stitcher.
                "method": "loftr",
                "loftr_model": "outdoor",
                "features": 100,
                "matcher": "smnn",
                "thr": 0.8,
            }
        )

    def temp_dir(self, files: Any) -> str:
        """Copy the uploaded images into a clean scratch folder.

        The folder ``<system temp>/uploaded_images`` is wiped and recreated on
        every call, so it only ever holds the images of the current request.

        Args:
            files: Value of the ``gr.Files`` component: an iterable of file
                paths, or of file wrappers exposing the path as ``.name``
                (older Gradio versions). ``None`` is treated as no files.

        Returns:
            The path of the scratch folder as a string.
        """
        upload_dir = Path(tempfile.gettempdir()) / "uploaded_images"
        if upload_dir.exists():
            shutil.rmtree(upload_dir)
        upload_dir.mkdir(parents=True, exist_ok=True)
        for file in files or []:
            source = Path(file if isinstance(file, (str, Path)) else file.name)
            destination = upload_dir / source.name
            # Two uploads may share a basename; keep both instead of letting
            # the second silently overwrite the first.
            suffix_index = 1
            while destination.exists():
                destination = upload_dir / f"{source.stem}_{suffix_index}{source.suffix}"
                suffix_index += 1
            # Copy instead of move: the uploaded file must still exist if the
            # user presses "Stitch images" again with the same uploads.
            shutil.copy2(source, destination)
        return str(upload_dir)

    def callback(self, files: Any) -> Optional[Any]:
        """Stitch the uploaded images with the currently selected stitcher.

        Args:
            files: Value of the ``gr.Files`` upload component.

        Returns:
            Whatever the selected stitcher's ``stitcher()`` returns; it is
            displayed in a ``gr.Image(type="pil")`` output.

        Raises:
            gr.Error: If no images were uploaded.
        """
        if not files:
            raise gr.Error("Please upload at least one image to stitch.")
        print(self.param_values)
        params = self.param_values
        image_dir = Path(self.temp_dir(files))
        model = params["model"]

        if model == "Simple Stitcher":
            simple_stitcher = SimpleStitcher(
                image_dir=image_dir,
                stitcher_type=params["stitcher_type"],
            )
            return simple_stitcher.stitcher()

        if model == "Detailed Stitcher":
            detailed_stitcher = DetailedStitcher(
                image_dir=image_dir,
                feature_number=params["num_feat"],
                device="cpu",
                detector_method=params["detect_method"],
                matcher_type=params["match_type"],
                confidence_threshold=params["conf_thr"],
                camera_adjustor=params["cam_adj"],
                camera_estimator=params["cam_est"],
            )
            return detailed_stitcher.stitcher()

        if model == "Kornia Stitcher":
            kornia_stitcher = KorniaStitcher(image_dir=image_dir)
            method = params["method"]
            if method == "loftr":
                kornia_stitcher.loftr_matcher(model=params["loftr_model"])
            elif method == "local":
                kornia_stitcher.local_matcher(
                    number_of_features=params["features"],
                    match_mode=params["matcher"],
                    thr=params["thr"],
                )
            elif method == "keynote":
                kornia_stitcher.keynote_matcher(
                    number_of_features=params["features"],
                    match_mode=params["matcher"],
                    thr=params["thr"],
                )
            return kornia_stitcher.stitcher()

        if model == "Sequential Stitcher":
            sequential_stitcher = SequentialStitcher(
                image_dir=image_dir,
                feature_detector=params["detector_method"],
                matcher_type=params["matching_method"],
                number_feature=params["number_feature"],
                final_size=(1000, 3000),
            )
            return sequential_stitcher.stitcher()

        # The only remaining option in the model radio is "Keypoint Stitcher".
        keypoint_stitcher = KeypointStitcher(
            image_dir=image_dir,
            feature_detector=params["detector_method"],
            matcher_type=params["matching_method"],
            number_feature=params["number_feature"],
        )
        return keypoint_stitcher.stitcher()

    def _design_simple_parameter(self) -> None:
        """Add the accordion with the Simple Stitcher's ``stitcher_type`` radio."""
        with gr.Accordion("Simple Stitcher Parameters", open=False):
            with gr.Row():
                self.parameters["stitcher_type"] = gr.Radio(
                    ["scan", "panorama"],
                    value=self.param_values["stitcher_type"],
                    label="Stitcher type",
                )

    def _design_detailed_parameter(self) -> None:
        """Add the accordion with the Detailed Stitcher's parameter inputs."""
        with gr.Accordion("Detailed Stitcher Parameters", open=False):
            with gr.Row():
                self.parameters["detect_method"] = gr.Radio(
                    ["sift", "orb", "brisk", "akaze"],
                    value=self.param_values["detect_method"],
                    label="Detection method",
                )
                self.parameters["match_type"] = gr.Radio(
                    ["affine", "homography"],
                    value=self.param_values["match_type"],
                    label="Matching method",
                )
                self.parameters["num_feat"] = gr.Number(
                    value=self.param_values["num_feat"],
                    precision=0,
                    label="No. of feature",
                )
                self.parameters["conf_thr"] = gr.Number(
                    value=self.param_values["conf_thr"], label="Confidence threshold"
                )
                self.parameters["cam_est"] = gr.Radio(
                    ["affine", "homography"],
                    value=self.param_values["cam_est"],
                    label="Camera estimator",
                )
                self.parameters["cam_adj"] = gr.Radio(
                    ["ray", "reproj", "affine", "no"],
                    value=self.param_values["cam_adj"],
                    label="Camera adjustor",
                )

    def _design_keypoint_parameter(self) -> None:
        """Add the accordion shared by the Keypoint and Sequential stitchers."""
        with gr.Accordion("Keypoint/Sequential Stitcher Parameters", open=False):
            with gr.Row():
                self.parameters["matching_method"] = gr.Radio(
                    ["bf", "flann"],
                    value=self.param_values["matching_method"],
                    label="matching method",
                )
                self.parameters["detector_method"] = gr.Radio(
                    ["sift", "orb", "brisk"],
                    value=self.param_values["detector_method"],
                    label="detecting method",
                )
                self.parameters["number_feature"] = gr.Number(
                    value=self.param_values["number_feature"],
                    precision=0,
                    label="number of feature",
                )

    def _design_kornia_parameter(self) -> None:
        """Add the accordion with the Kornia Stitcher's parameter inputs."""
        with gr.Accordion("Kornia Stitcher Parameters", open=False):
            with gr.Row():
                self.parameters["method"] = gr.Radio(
                    ["loftr", "local", "keynote"],
                    value=self.param_values["method"],
                    label="Matching method",
                )
                self.parameters["loftr_model"] = gr.Radio(
                    ["outdoor", "indoor"],
                    value=self.param_values["loftr_model"],
                    label="loftr model type",
                )
                self.parameters["features"] = gr.Number(
                    value=self.param_values["features"],
                    precision=0,
                    label="No. of features in local/keynote methods",
                )
                self.parameters["matcher"] = gr.Radio(
                    ["snn", "nn", "mnn", "smnn"],
                    value=self.param_values["matcher"],
                    label="Matcher mode in local/keynote methods.",
                )
                self.parameters["thr"] = gr.Number(
                    value=self.param_values["thr"],
                    label="Threshold for local/keynote method",
                )

    def dummy_logger(self, val: Any, key: str) -> None:
        """Store a new UI value for parameter ``key`` in ``param_values``.

        Registered as the ``.change`` handler of every input component.
        Gradio calls it with the component's *value* (a ``str`` for a radio, a
        number for ``gr.Number``), not with the component object.

        Args:
            val: The new value emitted by the component.
            key: Name of the parameter in ``param_values``.
        """
        if val is None:
            # A cleared number box may emit None; keep the previous value.
            return
        current = self.param_values.get(key)
        # Coerce to the type of the default so feature counts stay ints and
        # thresholds stay floats, whatever numeric type the widget sends.
        if isinstance(current, int):
            self.param_values[key] = int(round(float(val)))
        elif isinstance(current, float):
            self.param_values[key] = float(val)
        else:
            self.param_values[key] = str(val)

    def update_radio_button(self) -> None:
        """Mirror every input component's changes into ``param_values``.

        Must run inside the ``gr.Blocks`` context after all components in
        ``parameters`` exist. Despite the name it wires the number inputs as
        well as the radios.
        """
        for key, component in self.parameters.items():
            # ``gr.State(key)`` feeds the parameter name to the handler as a
            # constant second input.
            component.change(
                self.dummy_logger, inputs=[component, gr.State(key)], outputs=None
            )

    def demo(self) -> None:
        """Build the Gradio page and launch the server."""
        with gr.Blocks() as demo:
            with gr.Column():
                self.parameters["model"] = gr.Radio(
                    [
                        "Simple Stitcher",
                        "Kornia Stitcher",
                        "Sequential Stitcher",
                        "Keypoint Stitcher",
                        "Detailed Stitcher",
                    ],
                    value=self.param_values["model"],
                    label="Select the stitcher type",
                )
                self._design_simple_parameter()
                self._design_kornia_parameter()
                self._design_keypoint_parameter()
                self._design_detailed_parameter()
                # Wire change handlers only after every parameter component
                # has been created.
                self.update_radio_button()
                files = gr.Files(file_types=["image"], file_count="multiple")
                submit_btn = gr.Button(value="Stitch images")
            with gr.Column():
                result = gr.Image(type="pil")
            submit_btn.click(
                self.callback, inputs=[files], outputs=result, api_name=False
            )
        demo.launch()
|
|
|
|
# Build and launch the demo only when run as a script (e.g. `python app.py`),
# so importing this module does not start a Gradio server as a side effect.
if __name__ == "__main__":
    stitching_demo = StitcherDemo()
    stitching_demo.demo()
|
|