Files changed (1) hide show
  1. app.py +30 -2
app.py CHANGED
@@ -4,7 +4,7 @@ from tensorflow.keras.preprocessing.image import ImageDataGenerator
4
  from tensorflow.keras.models import Sequential
5
  from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
6
  import numpy as np
7
- from PIL import Image
8
  import os
9
 
10
  # -----------------------------
@@ -16,10 +16,38 @@ IMG_SIZE = (128, 128)
16
  BATCH_SIZE = 32
17
  EPOCHS = 5
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # -----------------------------
20
  # TRAIN MODEL FUNCTION
21
  # -----------------------------
22
  def train_model():
 
 
 
23
  datagen = ImageDataGenerator(
24
  rescale=1./255,
25
  validation_split=0.2
@@ -70,7 +98,7 @@ def train_model():
70
  return model, list(train_data.class_indices.keys())
71
 
72
  # -----------------------------
73
- # LOAD MODEL
74
  # -----------------------------
75
  if not os.path.exists(MODEL_PATH):
76
  st.warning("Training model for first-time use. Please wait...")
 
4
  from tensorflow.keras.models import Sequential
5
  from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
6
  import numpy as np
7
+ from PIL import Image, UnidentifiedImageError
8
  import os
9
 
10
  # -----------------------------
 
16
  BATCH_SIZE = 32
17
  EPOCHS = 5
18
 
19
+ # -----------------------------
20
+ # REMOVE CORRUPTED IMAGES
21
+ # -----------------------------
22
+ def clean_dataset(dataset_path):
23
+ valid_extensions = (".jpg", ".jpeg", ".png")
24
+
25
+ removed = 0
26
+ for root, dirs, files in os.walk(dataset_path):
27
+ for file in files:
28
+ file_path = os.path.join(root, file)
29
+
30
+ if not file.lower().endswith(valid_extensions):
31
+ os.remove(file_path)
32
+ removed += 1
33
+ continue
34
+
35
+ try:
36
+ with Image.open(file_path) as img:
37
+ img.verify()
38
+ except (UnidentifiedImageError, OSError):
39
+ os.remove(file_path)
40
+ removed += 1
41
+
42
+ return removed
43
+
44
  # -----------------------------
45
  # TRAIN MODEL FUNCTION
46
  # -----------------------------
47
  def train_model():
48
+ removed_files = clean_dataset(DATASET_DIR)
49
+ st.info(f"Removed {removed_files} corrupted/invalid files.")
50
+
51
  datagen = ImageDataGenerator(
52
  rescale=1./255,
53
  validation_split=0.2
 
98
  return model, list(train_data.class_indices.keys())
99
 
100
  # -----------------------------
101
+ # LOAD OR TRAIN MODEL
102
  # -----------------------------
103
  if not os.path.exists(MODEL_PATH):
104
  st.warning("Training model for first-time use. Please wait...")