letitbE commited on
Commit
a8aedde
·
1 Parent(s): fb78a64

feat: add demo gallery images and dataset-dir fallback for examples

Browse files
README.md CHANGED
@@ -19,5 +19,6 @@ Live demo: [Hugging Face Space](https://huggingface.co/spaces/letitbE/mine2real)
19
  - `notebooks/lab5.ipynb` — training notebook
20
  - `checkpoints/cycle_gan_color_fix#0_epoch_10.pt` — trained last checkpoint used by the app
21
  - `data/` — not in git (Space repo size limit); download the dataset from the link in the Space / course materials, then extract so you have `data/datasets/` (and optionally `data/datasets.tar.gz`) locally
 
22
  - `app.py` — streamlit app
23
  - `mine2real/` — model definition and inference utilities
 
19
  - `notebooks/lab5.ipynb` — training notebook
20
  - `checkpoints/cycle_gan_color_fix#0_epoch_10.pt` — trained last checkpoint used by the app
21
  - `data/` — not in git (Space repo size limit); download the dataset from the link in the Space / course materials, then extract so you have `data/datasets/` (and optionally `data/datasets.tar.gz`) locally
22
+ - `demo_gallery/` — small fixed subset (10 images) shipped in git for the Streamlit example gallery when the full dataset is absent
23
  - `app.py` — streamlit app
24
  - `mine2real/` — model definition and inference utilities
app.py CHANGED
@@ -113,19 +113,19 @@ def get_model():
113
  @st.cache_data(show_spinner=False)
114
  def load_example_gallery() -> dict[str, list[Path]]:
115
  return {
116
- "minecraft": list_example_images(MINECRAFT_DIR, limit=8, seed=7),
117
- "real": list_example_images(REAL_DIR, limit=8, seed=11),
118
  }
119
 
120
 
121
  def render_result(result: TranslationResult, input_label: str, output_label: str) -> None:
122
  col1, col2, col3 = st.columns(3)
123
- col1.image(result.source, caption=input_label, use_container_width=True)
124
- col2.image(result.translated, caption=output_label, use_container_width=True)
125
  col3.image(
126
  result.reconstructed,
127
  caption="Cycle reconstruction",
128
- use_container_width=True,
129
  )
130
 
131
 
@@ -209,7 +209,7 @@ def app() -> None:
209
  help="The app accepts exactly one image and runs a full CycleGAN pass with reconstruction.",
210
  )
211
 
212
- run_button = st.button("Translate image", type="primary", use_container_width=True)
213
 
214
  if uploaded is not None:
215
  image = Image.open(uploaded).convert("RGB")
@@ -243,7 +243,7 @@ def app() -> None:
243
  """
244
  )
245
  else:
246
- st.image(image, caption="Uploaded image", use_container_width=True)
247
  else:
248
  st.info("Upload an image and click `Translate image` to run the model.")
249
 
@@ -253,7 +253,8 @@ def app() -> None:
253
  """
254
  - Notebook: `notebooks/lab5.ipynb`
255
  - Checkpoint: `checkpoints/cycle_gan_color_fix#0_epoch_10.pt`
256
- - Dataset: not stored in this repo (HF Space storage limit); obtain the archive from the published link and extract to `data/datasets/` for local training and example galleries
 
257
  """
258
  )
259
 
@@ -265,7 +266,7 @@ def app() -> None:
265
  galleries = load_example_gallery()
266
  if not galleries["minecraft"] and not galleries["real"]:
267
  st.info(
268
- "Example gallery is empty: add `data/datasets/` (see README) or the dataset from your course link."
269
  )
270
  block_a, block_b = st.columns(2)
271
 
@@ -282,7 +283,7 @@ def app() -> None:
282
  st.write("")
283
  cols = st.columns(2)
284
  for idx, image_path in enumerate(galleries["minecraft"]):
285
- cols[idx % 2].image(str(image_path), caption=image_path.name, use_container_width=True)
286
 
287
  with block_b:
288
  st.markdown(
@@ -297,7 +298,7 @@ def app() -> None:
297
  st.write("")
298
  cols = st.columns(2)
299
  for idx, image_path in enumerate(galleries["real"]):
300
- cols[idx % 2].image(str(image_path), caption=image_path.name, use_container_width=True)
301
 
302
 
303
  if __name__ == "__main__":
 
113
  @st.cache_data(show_spinner=False)
114
  def load_example_gallery() -> dict[str, list[Path]]:
115
  return {
116
+ "minecraft": list_example_images(MINECRAFT_DIR, limit=5, seed=7),
117
+ "real": list_example_images(REAL_DIR, limit=5, seed=11),
118
  }
119
 
120
 
121
  def render_result(result: TranslationResult, input_label: str, output_label: str) -> None:
122
  col1, col2, col3 = st.columns(3)
123
+ col1.image(result.source, caption=input_label, width="stretch")
124
+ col2.image(result.translated, caption=output_label, width="stretch")
125
  col3.image(
126
  result.reconstructed,
127
  caption="Cycle reconstruction",
128
+ width="stretch",
129
  )
130
 
131
 
 
209
  help="The app accepts exactly one image and runs a full CycleGAN pass with reconstruction.",
210
  )
211
 
212
+ run_button = st.button("Translate image", type="primary", width="stretch")
213
 
214
  if uploaded is not None:
215
  image = Image.open(uploaded).convert("RGB")
 
243
  """
244
  )
245
  else:
246
+ st.image(image, caption="Uploaded image", width="stretch")
247
  else:
248
  st.info("Upload an image and click `Translate image` to run the model.")
249
 
 
253
  """
254
  - Notebook: `notebooks/lab5.ipynb`
255
  - Checkpoint: `checkpoints/cycle_gan_color_fix#0_epoch_10.pt`
256
+ - Dataset link: [Yandex Disk](https://disk.yandex.ru/d/N_G-t-oirnLynw)
257
+ - Example gallery in repo: `demo_gallery/` (5+5 images); full training data stays under `data/datasets/` when unpacked locally
258
  """
259
  )
260
 
 
266
  galleries = load_example_gallery()
267
  if not galleries["minecraft"] and not galleries["real"]:
268
  st.info(
269
+ "Example gallery is empty: ensure `demo_gallery/` is present in the app checkout, or restore `data/datasets/` from the Yandex Disk link above."
270
  )
271
  block_a, block_b = st.columns(2)
272
 
 
283
  st.write("")
284
  cols = st.columns(2)
285
  for idx, image_path in enumerate(galleries["minecraft"]):
286
+ cols[idx % 2].image(str(image_path), caption=image_path.name, width="stretch")
287
 
288
  with block_b:
289
  st.markdown(
 
298
  st.write("")
299
  cols = st.columns(2)
300
  for idx, image_path in enumerate(galleries["real"]):
301
+ cols[idx % 2].image(str(image_path), caption=image_path.name, width="stretch")
302
 
303
 
304
  if __name__ == "__main__":
demo_gallery/minecraft/flowers1_00061.jpg ADDED

Git LFS Details

  • SHA256: 8d57bef9ef06a00210931269b3fe2b6d5b9c1ddd1260ba5d0506d273896b66c4
  • Pointer size: 131 Bytes
  • Size of remote file: 352 kB
demo_gallery/minecraft/flowers1_00091.jpg ADDED

Git LFS Details

  • SHA256: 31748c4e42df6d59bb700bb39d35a546d78ea9f4061bdbf30b10d71825140cae
  • Pointer size: 131 Bytes
  • Size of remote file: 305 kB
demo_gallery/minecraft/flowers1_00121.jpg ADDED

Git LFS Details

  • SHA256: a4f11371cb0144c442a57e3a4cbb6ac872c1515ffe78ed9cba30c2b49f939bbf
  • Pointer size: 131 Bytes
  • Size of remote file: 287 kB
demo_gallery/minecraft/flowers1_00151.jpg ADDED

Git LFS Details

  • SHA256: af21cabebec6edc87d0f3b96e19c6131fe18afae18317b5678fee3248b09ebef
  • Pointer size: 131 Bytes
  • Size of remote file: 255 kB
demo_gallery/minecraft/flowers1_00181.jpg ADDED

Git LFS Details

  • SHA256: 58df08584afeec0085fc83b8433f7fe9f78c24cd33dd03b2964fdd34e6fef6b5
  • Pointer size: 131 Bytes
  • Size of remote file: 284 kB
demo_gallery/real/image_0.png ADDED

Git LFS Details

  • SHA256: 72636fb36edf6f45b528f132bcba7d0bb43615c481bc95ca5832c299da7caa36
  • Pointer size: 131 Bytes
  • Size of remote file: 383 kB
demo_gallery/real/image_1.png ADDED

Git LFS Details

  • SHA256: df07561174a00e095e482e27a0d42680dcf72e4a5424b6443171a296f700f6c4
  • Pointer size: 131 Bytes
  • Size of remote file: 436 kB
demo_gallery/real/image_10.png ADDED

Git LFS Details

  • SHA256: c94bfef085c8e058973f1fa2d8f86d9e75c7471435b16ad0d17d32144981af8e
  • Pointer size: 131 Bytes
  • Size of remote file: 424 kB
demo_gallery/real/image_100.png ADDED

Git LFS Details

  • SHA256: e9a6578522d26a9421f77995471dc1899ef1e893cc00b5efef2c12cc7cee5c6d
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
demo_gallery/real/image_1000.png ADDED

Git LFS Details

  • SHA256: dfe6315ec8409527b23f907bf209763630a01fe0085cc5e237c788dd2565e85f
  • Pointer size: 131 Bytes
  • Size of remote file: 441 kB
mine2real/inference.py CHANGED
@@ -7,15 +7,16 @@ import random
7
  import numpy as np
8
  from PIL import Image, ImageOps
9
  import torch
10
- from torchvision import transforms as T
11
 
12
  from mine2real.model import CycleGAN
13
 
14
  IMAGE_SIZE = 256
15
  CHECKPOINT_PATH = Path("checkpoints/cycle_gan_color_fix#0_epoch_10.pt")
16
- DATASET_ROOT = Path("data/datasets")
17
- MINECRAFT_DIR = DATASET_ROOT / "minecraft_forestlike"
18
- REAL_DIR = DATASET_ROOT / "real_nature_landscape"
 
 
19
 
20
 
21
  @dataclass(frozen=True)
@@ -34,14 +35,11 @@ def get_device() -> torch.device:
34
  return torch.device("cpu")
35
 
36
 
37
- def build_transform(image_size: int = IMAGE_SIZE) -> T.Compose:
38
- return T.Compose(
39
- [
40
- T.Resize((image_size, image_size)),
41
- T.ToTensor(),
42
- T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
43
- ]
44
- )
45
 
46
 
47
  def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
@@ -55,7 +53,7 @@ def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
55
  def prepare_image(image: Image.Image, image_size: int = IMAGE_SIZE) -> tuple[Image.Image, torch.Tensor]:
56
  source = image.convert("RGB")
57
  source = ImageOps.fit(source, (image_size, image_size), method=Image.Resampling.LANCZOS)
58
- tensor = build_transform(image_size)(source).unsqueeze(0)
59
  return source, tensor
60
 
61
 
 
7
  import numpy as np
8
  from PIL import Image, ImageOps
9
  import torch
 
10
 
11
  from mine2real.model import CycleGAN
12
 
13
  IMAGE_SIZE = 256
14
  CHECKPOINT_PATH = Path("checkpoints/cycle_gan_color_fix#0_epoch_10.pt")
15
+ _ROOT = Path(__file__).resolve().parent.parent
16
+ _full_mc = _ROOT / "data" / "datasets" / "minecraft_forestlike"
17
+ _full_real = _ROOT / "data" / "datasets" / "real_nature_landscape"
18
+ MINECRAFT_DIR = _full_mc if _full_mc.is_dir() else _ROOT / "demo_gallery" / "minecraft"
19
+ REAL_DIR = _full_real if _full_real.is_dir() else _ROOT / "demo_gallery" / "real"
20
 
21
 
22
  @dataclass(frozen=True)
 
35
  return torch.device("cpu")
36
 
37
 
38
+ def pil_to_normalized_tensor(image: Image.Image) -> torch.Tensor:
39
+ array = np.asarray(image, dtype=np.float32) / 255.0
40
+ tensor = torch.from_numpy(array).permute(2, 0, 1)
41
+ tensor = (tensor - 0.5) / 0.5
42
+ return tensor
 
 
 
43
 
44
 
45
  def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
 
53
  def prepare_image(image: Image.Image, image_size: int = IMAGE_SIZE) -> tuple[Image.Image, torch.Tensor]:
54
  source = image.convert("RGB")
55
  source = ImageOps.fit(source, (image_size, image_size), method=Image.Resampling.LANCZOS)
56
+ tensor = pil_to_normalized_tensor(source).unsqueeze(0)
57
  return source, tensor
58
 
59
 
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
  streamlit>=1.40,<2
2
  numpy>=1.26,<3
3
  pillow>=10,<12
4
- torchvision>=0.20,<0.23
 
1
  streamlit>=1.40,<2
2
  numpy>=1.26,<3
3
  pillow>=10,<12