Charuka66 commited on
Commit
96b923d
·
verified ·
1 Parent(s): d6c6958

Update augmentation script to scale Teacher dataset to 900 images

Browse files

Scaled the base dataset target from 100 to 300 images per class (Blast, Brown Spot, Sheath Blight) — 900 images in total — to resolve the underfitting and poor accuracy (35%) observed in the initial Teacher model.
Retained geometric transformations (horizontal/vertical flips) with precise YOLO polygon coordinate inversions.
Retained photometric transformations (Gaussian noise and brightness scaling) to improve model robustness against varied environmental drone lighting.

Files changed (1) hide show
  1. augment.py +79 -92
augment.py CHANGED
@@ -6,143 +6,130 @@ from glob import glob
6
  from tqdm import tqdm
7
 
8
# ================= CONFIGURATION =================
# 1. Path to your Seed Dataset (The 36 images you just labelled)
IMAGES_DIR = r"C:\Users\charu\Desktop\My_Project\seed_images"
LABELS_DIR = r"C:\Users\charu\Desktop\My_Project\seed_labels"

# 2. Target count per class (You wanted 100)
TARGET_COUNT = 100

# 3. Your Classes
# Note: "Healthy" is identified by empty text files
# NOTE(review): keys mix int YOLO class ids with the string "Healthy";
# every consumer of this dict must handle both key types.
CLASS_NAMES = {0: "Blast", 1: "Brown Spot", 2: "Sheath Blight", "Healthy": "Healthy"}
# =================================================
20
 
21
def load_data():
    """Sorts images into lists based on what disease they contain.

    Scans LABELS_DIR for YOLO *.txt files, pairs each with a .jpg/.png in
    IMAGES_DIR, and buckets (image_path, label_lines) tuples by the class id
    of the FIRST annotation line. An empty label file means "Healthy".
    Returns a dict keyed by 0, 1, 2 and "Healthy".
    """
    data_map = {0: [], 1: [], 2: [], "Healthy": []}

    # Get all text files
    txt_files = glob(os.path.join(LABELS_DIR, "*.txt"))

    for txt_path in txt_files:
        # NOTE(review): replace() strips ANY ".txt" occurrence, not just the
        # trailing extension — os.path.splitext would be safer.
        filename = os.path.basename(txt_path).replace('.txt', '')

        # Find matching image
        img_jpg = os.path.join(IMAGES_DIR, filename + ".jpg")
        img_png = os.path.join(IMAGES_DIR, filename + ".png")

        if os.path.exists(img_jpg): img_path = img_jpg
        elif os.path.exists(img_png): img_path = img_png
        else: continue # Skip if image missing

        # Read the label
        with open(txt_path, 'r') as f:
            lines = f.readlines()

        # CLASSIFY THE IMAGE
        if not lines:
            # Empty file = Healthy
            data_map["Healthy"].append((img_path, lines))
        else:
            # Check the first class ID in the file
            # NOTE(review): int() is unguarded — a malformed first line
            # raises ValueError and aborts the whole scan.
            class_id = int(lines[0].split()[0])
            if class_id in data_map:
                data_map[class_id].append((img_path, lines))

    return data_map
54
-
55
def augment_and_save(img_path, lines, new_name):
    """Write one randomly-augmented copy of an image and its YOLO labels.

    Picks one of five augmentations; the flips invert the normalized polygon
    coordinates (x for h_flip, y for v_flip), the photometric ones reuse the
    original label lines unchanged. Output goes to IMAGES_DIR/LABELS_DIR
    under `new_name` (.jpg / .txt).
    """
    img = cv2.imread(img_path)
    if img is None: return

    # Random Augmentation Strategy
    aug_type = random.choice(["h_flip", "v_flip", "noise", "bright", "dark"])
    new_lines = []

    h, w, _ = img.shape

    # 1. Horizontal Flip
    if aug_type == "h_flip":
        new_img = cv2.flip(img, 1)
        for line in lines:
            parts = line.strip().split()
            cls = parts[0]
            coords = [float(x) for x in parts[1:]]
            # Math: New X = 1.0 - Old X (even indices are x in "cls x1 y1 x2 y2 ...")
            new_coords = [1.0 - val if i % 2 == 0 else val for i, val in enumerate(coords)]
            new_lines.append(f"{cls} " + " ".join([f"{c:.6f}" for c in new_coords]) + "\n")

    # 2. Vertical Flip
    elif aug_type == "v_flip":
        new_img = cv2.flip(img, 0)
        for line in lines:
            parts = line.strip().split()
            cls = parts[0]
            coords = [float(x) for x in parts[1:]]
            # Math: New Y = 1.0 - Old Y (odd indices are y)
            new_coords = [1.0 - val if i % 2 != 0 else val for i, val in enumerate(coords)]
            new_lines.append(f"{cls} " + " ".join([f"{c:.6f}" for c in new_coords]) + "\n")

    # 3. Noise (Grainy)
    elif aug_type == "noise":
        # NOTE(review): casting Gaussian samples to uint8 wraps negative
        # values around to ~230-255, so this noise is not zero-mean —
        # add in a signed dtype and clip instead.
        noise = np.random.normal(0, 25, img.shape).astype(np.uint8)
        new_img = cv2.add(img, noise)
        new_lines = lines # Coordinates don't change

    # 4. Brightness
    elif aug_type == "bright":
        new_img = cv2.convertScaleAbs(img, alpha=1.1, beta=20)
        new_lines = lines

    # 5. Darkness
    elif aug_type == "dark":
        new_img = cv2.convertScaleAbs(img, alpha=0.9, beta=-20)
        new_lines = lines

    # Save Image
    cv2.imwrite(os.path.join(IMAGES_DIR, new_name + ".jpg"), new_img)

    # Save Label
    with open(os.path.join(LABELS_DIR, new_name + ".txt"), 'w') as f:
        f.writelines(new_lines)
109
 
110
def main():
    """Scan the seed dataset and top every class up to TARGET_COUNT images."""
    print("🚀 scanning seed dataset...")
    data_map = load_data()

    print("\n📊 Current counts:")
    for key, items in data_map.items():
        name = CLASS_NAMES[key]
        print(f" - {name}: {len(items)} images")

    print(f"\n🛠️ Augmenting to reach {TARGET_COUNT} per class...")

    for key, items in data_map.items():
        name = CLASS_NAMES[key]
        current_count = len(items)
        needed = TARGET_COUNT - current_count

        if needed <= 0:
            print(f"✅ {name} is already full. Skipping.")
            continue

        if current_count == 0:
            print(f"⚠️ Warning: No images found for {name}. Cannot augment!")
            continue

        print(f" -> Generating {needed} images for {name}...")

        for i in tqdm(range(needed)):
            # Pick a random source image to clone
            source_img, source_lines = random.choice(items)

            # Create unique name
            new_filename = f"aug_{key}_{i}"

            augment_and_save(source_img, source_lines, new_filename)

    # 400 = 4 buckets (3 diseases + Healthy) x TARGET_COUNT of 100
    print("\n✅ Augmentation Complete! You now have 400 images.")
146
 
147
if __name__ == "__main__":
    main()  # run only when executed directly, not on import
 
6
  from tqdm import tqdm
7
 
8
# ================= CONFIGURATION =================
# 1. PATHS (Separated): images and their YOLO label .txt files live in
#    sibling folders; augmented output is written back into the same folders.
IMAGES_DIR = r"C:\Users\charu\Desktop\04-02-2026\images"
LABELS_DIR = r"C:\Users\charu\Desktop\04-02-2026\labels"

# 2. Target Count per class (3 classes x 300 = 900 images total)
TARGET_PER_CLASS = 300

# 3. Class Names: YOLO class id -> human-readable disease name
CLASS_NAMES = {0: "Blast", 1: "Brown Spot", 2: "Sheath Blight"}
# =================================================
19
 
20
def load_dataset():
    """Group the labelled images by class id.

    Scans LABELS_DIR for YOLO *.txt files, pairs each with a matching
    .jpg/.png/.jpeg in IMAGES_DIR, and buckets (image_path, label_lines)
    tuples by the class id of the FIRST annotation line in the file.
    Labels with no matching image, empty label files, and malformed first
    lines are skipped.

    Returns:
        dict mapping class id (0, 1, 2) to a list of
        (image path, raw label lines) tuples.
    """
    dataset = {0: [], 1: [], 2: []}

    # scan labels folder
    txt_files = glob(os.path.join(LABELS_DIR, "*.txt"))

    print(f"📂 Scanning Labels: {LABELS_DIR}")
    print(f" -> Found {len(txt_files)} text files.")

    if not txt_files:
        print(" Error: No text files found! Check the path.")
        return dataset

    for txt_path in txt_files:
        # splitext strips only the trailing ".txt"; replace('.txt', '')
        # would also mangle a ".txt" appearing mid-name.
        filename = os.path.splitext(os.path.basename(txt_path))[0]

        # Look for matching image in IMAGES_DIR, trying each known extension.
        img_path = None
        for ext in (".jpg", ".png", ".jpeg"):
            candidate = os.path.join(IMAGES_DIR, filename + ext)
            if os.path.exists(candidate):
                img_path = candidate
                break
        if img_path is None:
            # If no image found for this label, skip it
            continue

        with open(txt_path, 'r') as f:
            lines = f.readlines()

        if lines:
            try:
                # Read class ID from the first annotation line; it decides
                # which bucket the whole image goes into.
                class_id = int(lines[0].split()[0])
            except (ValueError, IndexError):
                # Malformed first line (blank or non-numeric id) — skip the
                # file instead of silently swallowing arbitrary exceptions.
                continue
            if class_id in dataset:
                dataset[class_id].append((img_path, lines))
    return dataset
60
+
61
def _flip_label_lines(lines, flip_x):
    """Mirror normalized YOLO polygon labels across one axis.

    Each line is "cls x1 y1 x2 y2 ...": after the class id, coordinates
    alternate x, y, x, y. When flip_x is True every x becomes 1 - x
    (horizontal flip); otherwise every y becomes 1 - y (vertical flip).
    Returns new label lines formatted to 6 decimals.
    """
    flipped = []
    for line in lines:
        parts = line.strip().split()
        cls = parts[0]
        coords = [float(v) for v in parts[1:]]
        new_coords = []
        for i, val in enumerate(coords):
            # Even indices are X, odd indices are Y.
            if (i % 2 == 0) == flip_x:
                new_coords.append(1.0 - val)
            else:
                new_coords.append(val)
        flipped.append(f"{cls} " + " ".join([f"{c:.6f}" for c in new_coords]) + "\n")
    return flipped


def augment_polygon(img_path, lines, new_filename):
    """Create one randomly-augmented copy of an image and its labels.

    Applies one of: horizontal flip, vertical flip, brightness shift, or
    Gaussian noise. Flips also mirror the polygon coordinates; photometric
    changes reuse the label lines unchanged. Writes new_filename.jpg to
    IMAGES_DIR and new_filename.txt to LABELS_DIR. Silently returns if the
    image cannot be read.

    Args:
        img_path: path to the source image.
        lines: raw YOLO label lines for the source image.
        new_filename: output basename (no extension).
    """
    img = cv2.imread(img_path)
    if img is None: return

    action = random.choice(["h_flip", "v_flip", "bright", "noise"])

    if action == "h_flip":
        new_img = cv2.flip(img, 1)  # mirror left-right
        new_lines = _flip_label_lines(lines, flip_x=True)

    elif action == "v_flip":
        new_img = cv2.flip(img, 0)  # mirror top-bottom
        new_lines = _flip_label_lines(lines, flip_x=False)

    elif action == "bright":
        beta = random.randint(-30, 30)
        new_img = cv2.convertScaleAbs(img, alpha=1.0, beta=beta)
        new_lines = lines  # photometric only: coordinates unchanged

    elif action == "noise":
        # Zero-mean Gaussian noise. Add in a signed dtype and clip —
        # casting the noise itself to uint8 (as before) wraps negative
        # samples around to ~230-255 and biases the image bright.
        noise = np.random.normal(0, 15, img.shape)
        new_img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)
        new_lines = lines

    else:
        # Unreachable with the current action list; kept as a safe fallback.
        new_img = img
        new_lines = lines

    # SAVE TO SEPARATE FOLDERS
    cv2.imwrite(os.path.join(IMAGES_DIR, new_filename + ".jpg"), new_img)
    with open(os.path.join(LABELS_DIR, new_filename + ".txt"), 'w') as f:
        f.writelines(new_lines)
110
 
111
def main():
    """Load the split-folder dataset and top every class up to TARGET_PER_CLASS.

    For each class with at least one seed image, generates enough augmented
    copies (via augment_polygon) to reach TARGET_PER_CLASS images.
    """
    print("🚀 Loading Dataset (Separated Folders)...")
    data_map = load_dataset()

    print("\n📊 Current Counts:")
    # Drive both loops from CLASS_NAMES so adding a class is a one-line edit
    # instead of updating hard-coded [0, 1, 2] lists.
    for cid in sorted(CLASS_NAMES):
        print(f" - {CLASS_NAMES[cid]}: {len(data_map[cid])} images")

    print("\n🛠️ augmenting...")
    for cid in sorted(CLASS_NAMES):
        items = data_map[cid]
        needed = TARGET_PER_CLASS - len(items)

        if needed > 0 and items:
            print(f" -> Generating {needed} images for {CLASS_NAMES[cid]}...")
            for i in tqdm(range(needed)):
                # Clone a random seed image and augment it under a unique name.
                src_img, src_lines = random.choice(items)
                augment_polygon(src_img, src_lines, f"aug_{cid}_{i}")
        elif not items:
            print(f"⚠️ Warning: No images found for {CLASS_NAMES[cid]}!")
        else:
            # Previously this case was silent; report it like the seed script did.
            print(f"✅ {CLASS_NAMES[cid]} already has {len(items)} images. Skipping.")

    print("\n✅ Done!")
 
133
 
134
if __name__ == "__main__":
    main()  # run only when executed directly, not on import