feat: add demo gallery images and dataset-dir fallback for examples
Browse files- README.md +1 -0
- app.py +12 -11
- demo_gallery/minecraft/flowers1_00061.jpg +3 -0
- demo_gallery/minecraft/flowers1_00091.jpg +3 -0
- demo_gallery/minecraft/flowers1_00121.jpg +3 -0
- demo_gallery/minecraft/flowers1_00151.jpg +3 -0
- demo_gallery/minecraft/flowers1_00181.jpg +3 -0
- demo_gallery/real/image_0.png +3 -0
- demo_gallery/real/image_1.png +3 -0
- demo_gallery/real/image_10.png +3 -0
- demo_gallery/real/image_100.png +3 -0
- demo_gallery/real/image_1000.png +3 -0
- mine2real/inference.py +11 -13
- requirements.txt +0 -1
README.md
CHANGED
|
@@ -19,5 +19,6 @@ Live demo: [Hugging Face Space](https://huggingface.co/spaces/letitbE/mine2real)
|
|
| 19 |
- `notebooks/lab5.ipynb` — training notebook
|
| 20 |
- `checkpoints/cycle_gan_color_fix#0_epoch_10.pt` — trained last checkpoint used by the app
|
| 21 |
- `data/` — not in git (Space repo size limit); download the dataset from the link in the Space / course materials, then extract so you have `data/datasets/` (and optionally `data/datasets.tar.gz`) locally
|
|
|
|
| 22 |
- `app.py` — streamlit app
|
| 23 |
- `mine2real/` — model definition and inference utilities
|
|
|
|
| 19 |
- `notebooks/lab5.ipynb` — training notebook
|
| 20 |
- `checkpoints/cycle_gan_color_fix#0_epoch_10.pt` — trained last checkpoint used by the app
|
| 21 |
- `data/` — not in git (Space repo size limit); download the dataset from the link in the Space / course materials, then extract so you have `data/datasets/` (and optionally `data/datasets.tar.gz`) locally
|
| 22 |
+
- `demo_gallery/` — small fixed subset (10 images) shipped in git for the Streamlit example gallery when the full dataset is absent
|
| 23 |
- `app.py` — streamlit app
|
| 24 |
- `mine2real/` — model definition and inference utilities
|
app.py
CHANGED
|
@@ -113,19 +113,19 @@ def get_model():
|
|
| 113 |
@st.cache_data(show_spinner=False)
|
| 114 |
def load_example_gallery() -> dict[str, list[Path]]:
|
| 115 |
return {
|
| 116 |
-
"minecraft": list_example_images(MINECRAFT_DIR, limit=
|
| 117 |
-
"real": list_example_images(REAL_DIR, limit=
|
| 118 |
}
|
| 119 |
|
| 120 |
|
| 121 |
def render_result(result: TranslationResult, input_label: str, output_label: str) -> None:
|
| 122 |
col1, col2, col3 = st.columns(3)
|
| 123 |
-
col1.image(result.source, caption=input_label,
|
| 124 |
-
col2.image(result.translated, caption=output_label,
|
| 125 |
col3.image(
|
| 126 |
result.reconstructed,
|
| 127 |
caption="Cycle reconstruction",
|
| 128 |
-
|
| 129 |
)
|
| 130 |
|
| 131 |
|
|
@@ -209,7 +209,7 @@ def app() -> None:
|
|
| 209 |
help="The app accepts exactly one image and runs a full CycleGAN pass with reconstruction.",
|
| 210 |
)
|
| 211 |
|
| 212 |
-
run_button = st.button("Translate image", type="primary",
|
| 213 |
|
| 214 |
if uploaded is not None:
|
| 215 |
image = Image.open(uploaded).convert("RGB")
|
|
@@ -243,7 +243,7 @@ def app() -> None:
|
|
| 243 |
"""
|
| 244 |
)
|
| 245 |
else:
|
| 246 |
-
st.image(image, caption="Uploaded image",
|
| 247 |
else:
|
| 248 |
st.info("Upload an image and click `Translate image` to run the model.")
|
| 249 |
|
|
@@ -253,7 +253,8 @@ def app() -> None:
|
|
| 253 |
"""
|
| 254 |
- Notebook: `notebooks/lab5.ipynb`
|
| 255 |
- Checkpoint: `checkpoints/cycle_gan_color_fix#0_epoch_10.pt`
|
| 256 |
-
- Dataset
|
|
|
|
| 257 |
"""
|
| 258 |
)
|
| 259 |
|
|
@@ -265,7 +266,7 @@ def app() -> None:
|
|
| 265 |
galleries = load_example_gallery()
|
| 266 |
if not galleries["minecraft"] and not galleries["real"]:
|
| 267 |
st.info(
|
| 268 |
-
"Example gallery is empty:
|
| 269 |
)
|
| 270 |
block_a, block_b = st.columns(2)
|
| 271 |
|
|
@@ -282,7 +283,7 @@ def app() -> None:
|
|
| 282 |
st.write("")
|
| 283 |
cols = st.columns(2)
|
| 284 |
for idx, image_path in enumerate(galleries["minecraft"]):
|
| 285 |
-
cols[idx % 2].image(str(image_path), caption=image_path.name,
|
| 286 |
|
| 287 |
with block_b:
|
| 288 |
st.markdown(
|
|
@@ -297,7 +298,7 @@ def app() -> None:
|
|
| 297 |
st.write("")
|
| 298 |
cols = st.columns(2)
|
| 299 |
for idx, image_path in enumerate(galleries["real"]):
|
| 300 |
-
cols[idx % 2].image(str(image_path), caption=image_path.name,
|
| 301 |
|
| 302 |
|
| 303 |
if __name__ == "__main__":
|
|
|
|
| 113 |
@st.cache_data(show_spinner=False)
|
| 114 |
def load_example_gallery() -> dict[str, list[Path]]:
|
| 115 |
return {
|
| 116 |
+
"minecraft": list_example_images(MINECRAFT_DIR, limit=5, seed=7),
|
| 117 |
+
"real": list_example_images(REAL_DIR, limit=5, seed=11),
|
| 118 |
}
|
| 119 |
|
| 120 |
|
| 121 |
def render_result(result: TranslationResult, input_label: str, output_label: str) -> None:
|
| 122 |
col1, col2, col3 = st.columns(3)
|
| 123 |
+
col1.image(result.source, caption=input_label, width="stretch")
|
| 124 |
+
col2.image(result.translated, caption=output_label, width="stretch")
|
| 125 |
col3.image(
|
| 126 |
result.reconstructed,
|
| 127 |
caption="Cycle reconstruction",
|
| 128 |
+
width="stretch",
|
| 129 |
)
|
| 130 |
|
| 131 |
|
|
|
|
| 209 |
help="The app accepts exactly one image and runs a full CycleGAN pass with reconstruction.",
|
| 210 |
)
|
| 211 |
|
| 212 |
+
run_button = st.button("Translate image", type="primary", width="stretch")
|
| 213 |
|
| 214 |
if uploaded is not None:
|
| 215 |
image = Image.open(uploaded).convert("RGB")
|
|
|
|
| 243 |
"""
|
| 244 |
)
|
| 245 |
else:
|
| 246 |
+
st.image(image, caption="Uploaded image", width="stretch")
|
| 247 |
else:
|
| 248 |
st.info("Upload an image and click `Translate image` to run the model.")
|
| 249 |
|
|
|
|
| 253 |
"""
|
| 254 |
- Notebook: `notebooks/lab5.ipynb`
|
| 255 |
- Checkpoint: `checkpoints/cycle_gan_color_fix#0_epoch_10.pt`
|
| 256 |
+
- Dataset link: [Yandex Disk](https://disk.yandex.ru/d/N_G-t-oirnLynw)
|
| 257 |
+
- Example gallery in repo: `demo_gallery/` (5+5 images); full training data stays under `data/datasets/` when unpacked locally
|
| 258 |
"""
|
| 259 |
)
|
| 260 |
|
|
|
|
| 266 |
galleries = load_example_gallery()
|
| 267 |
if not galleries["minecraft"] and not galleries["real"]:
|
| 268 |
st.info(
|
| 269 |
+
"Example gallery is empty: ensure `demo_gallery/` is present in the app checkout, or restore `data/datasets/` from the Yandex Disk link above."
|
| 270 |
)
|
| 271 |
block_a, block_b = st.columns(2)
|
| 272 |
|
|
|
|
| 283 |
st.write("")
|
| 284 |
cols = st.columns(2)
|
| 285 |
for idx, image_path in enumerate(galleries["minecraft"]):
|
| 286 |
+
cols[idx % 2].image(str(image_path), caption=image_path.name, width="stretch")
|
| 287 |
|
| 288 |
with block_b:
|
| 289 |
st.markdown(
|
|
|
|
| 298 |
st.write("")
|
| 299 |
cols = st.columns(2)
|
| 300 |
for idx, image_path in enumerate(galleries["real"]):
|
| 301 |
+
cols[idx % 2].image(str(image_path), caption=image_path.name, width="stretch")
|
| 302 |
|
| 303 |
|
| 304 |
if __name__ == "__main__":
|
demo_gallery/minecraft/flowers1_00061.jpg
ADDED
|
Git LFS Details
|
demo_gallery/minecraft/flowers1_00091.jpg
ADDED
|
Git LFS Details
|
demo_gallery/minecraft/flowers1_00121.jpg
ADDED
|
Git LFS Details
|
demo_gallery/minecraft/flowers1_00151.jpg
ADDED
|
Git LFS Details
|
demo_gallery/minecraft/flowers1_00181.jpg
ADDED
|
Git LFS Details
|
demo_gallery/real/image_0.png
ADDED
|
Git LFS Details
|
demo_gallery/real/image_1.png
ADDED
|
Git LFS Details
|
demo_gallery/real/image_10.png
ADDED
|
Git LFS Details
|
demo_gallery/real/image_100.png
ADDED
|
Git LFS Details
|
demo_gallery/real/image_1000.png
ADDED
|
Git LFS Details
|
mine2real/inference.py
CHANGED
|
@@ -7,15 +7,16 @@ import random
|
|
| 7 |
import numpy as np
|
| 8 |
from PIL import Image, ImageOps
|
| 9 |
import torch
|
| 10 |
-
from torchvision import transforms as T
|
| 11 |
|
| 12 |
from mine2real.model import CycleGAN
|
| 13 |
|
| 14 |
IMAGE_SIZE = 256
|
| 15 |
CHECKPOINT_PATH = Path("checkpoints/cycle_gan_color_fix#0_epoch_10.pt")
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
@dataclass(frozen=True)
|
|
@@ -34,14 +35,11 @@ def get_device() -> torch.device:
|
|
| 34 |
return torch.device("cpu")
|
| 35 |
|
| 36 |
|
| 37 |
-
def
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 43 |
-
]
|
| 44 |
-
)
|
| 45 |
|
| 46 |
|
| 47 |
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
|
@@ -55,7 +53,7 @@ def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
|
| 55 |
def prepare_image(image: Image.Image, image_size: int = IMAGE_SIZE) -> tuple[Image.Image, torch.Tensor]:
|
| 56 |
source = image.convert("RGB")
|
| 57 |
source = ImageOps.fit(source, (image_size, image_size), method=Image.Resampling.LANCZOS)
|
| 58 |
-
tensor =
|
| 59 |
return source, tensor
|
| 60 |
|
| 61 |
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
from PIL import Image, ImageOps
|
| 9 |
import torch
|
|
|
|
| 10 |
|
| 11 |
from mine2real.model import CycleGAN
|
| 12 |
|
| 13 |
IMAGE_SIZE = 256
|
| 14 |
CHECKPOINT_PATH = Path("checkpoints/cycle_gan_color_fix#0_epoch_10.pt")
|
| 15 |
+
_ROOT = Path(__file__).resolve().parent.parent
|
| 16 |
+
_full_mc = _ROOT / "data" / "datasets" / "minecraft_forestlike"
|
| 17 |
+
_full_real = _ROOT / "data" / "datasets" / "real_nature_landscape"
|
| 18 |
+
MINECRAFT_DIR = _full_mc if _full_mc.is_dir() else _ROOT / "demo_gallery" / "minecraft"
|
| 19 |
+
REAL_DIR = _full_real if _full_real.is_dir() else _ROOT / "demo_gallery" / "real"
|
| 20 |
|
| 21 |
|
| 22 |
@dataclass(frozen=True)
|
|
|
|
| 35 |
return torch.device("cpu")
|
| 36 |
|
| 37 |
|
| 38 |
+
def pil_to_normalized_tensor(image: Image.Image) -> torch.Tensor:
|
| 39 |
+
array = np.asarray(image, dtype=np.float32) / 255.0
|
| 40 |
+
tensor = torch.from_numpy(array).permute(2, 0, 1)
|
| 41 |
+
tensor = (tensor - 0.5) / 0.5
|
| 42 |
+
return tensor
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
|
|
|
| 53 |
def prepare_image(image: Image.Image, image_size: int = IMAGE_SIZE) -> tuple[Image.Image, torch.Tensor]:
|
| 54 |
source = image.convert("RGB")
|
| 55 |
source = ImageOps.fit(source, (image_size, image_size), method=Image.Resampling.LANCZOS)
|
| 56 |
+
tensor = pil_to_normalized_tensor(source).unsqueeze(0)
|
| 57 |
return source, tensor
|
| 58 |
|
| 59 |
|
requirements.txt
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
streamlit>=1.40,<2
|
| 2 |
numpy>=1.26,<3
|
| 3 |
pillow>=10,<12
|
| 4 |
-
torchvision>=0.20,<0.23
|
|
|
|
| 1 |
streamlit>=1.40,<2
|
| 2 |
numpy>=1.26,<3
|
| 3 |
pillow>=10,<12
|
|
|