"""Train a binary image classifier on top of a frozen MobileNetV2 backbone.

Images are read from ``data_dir`` (one sub-directory per class, as required by
``flow_from_directory``), split 80/20 into training and validation subsets,
and used to train a small dense head on ImageNet-pretrained MobileNetV2
features.  The trained model is written to ``strawberry_model.h5``.
"""

import os

import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Data directories
data_dir = 'dataset'

# Settings shared by both generators and the model input.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2
EPOCHS = 10

# NOTE(review): MobileNetV2's own ``preprocess_input`` scales pixels to
# [-1, 1]; this script feeds [0, 1].  The 1/255 rescale is kept so that any
# existing code that loads strawberry_model.h5 and feeds it [0, 1] inputs
# stays valid -- confirm before switching to ``preprocess_input``.
RESCALE = 1. / 255

# Training images are rescaled AND randomly augmented.
train_datagen = ImageDataGenerator(
    rescale=RESCALE,
    validation_split=VALIDATION_SPLIT,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
)

# Validation images are only rescaled.  Reusing the augmenting generator for
# the validation subset would randomly rotate/shift/flip the held-out images,
# making val_loss / val_accuracy noisy and not comparable between epochs.
# The same validation_split is used so that both generators partition the
# files identically.
validation_datagen = ImageDataGenerator(
    rescale=RESCALE,
    validation_split=VALIDATION_SPLIT,
)

train_generator = train_datagen.flow_from_directory(
    data_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='training',
)

validation_generator = validation_datagen.flow_from_directory(
    data_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='validation',
)

# Load pre-trained MobileNetV2 without its ImageNet classification head.
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=IMAGE_SIZE + (3,),
)

# Add custom layers: pooled backbone features -> dense layer -> one sigmoid
# unit, matching class_mode='binary' and the binary cross-entropy loss.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x)
model = Model(inputs=base_model.input, outputs=predictions)

# Freeze base layers so only the new head is trained.  This must happen
# before compile() for the trainable flags to take effect.
for layer in base_model.layers:
    layer.trainable = False

# Compile
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train
model.fit(train_generator, validation_data=validation_generator, epochs=EPOCHS)

# Save model
model.save('strawberry_model.h5')
print("Model trained and saved as strawberry_model.h5")