"""Train a binary image classifier on top of a frozen MobileNetV2 backbone.

Reads images from ``dataset/``, which must contain one sub-directory per
class. ``class_mode='binary'`` assumes exactly two classes, with labels
assigned in alphabetical order of the directory names. 20% of each class is
held out for validation. The trained model is saved to
``strawberry_model.h5``.

The saved model expects RGB images resized to 224x224 and scaled to
``[0, 1]`` (i.e. ``pixels / 255``). The conversion to the ``[-1, 1]`` range
that MobileNetV2's ImageNet weights were trained on happens inside the model.
"""
import os

import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, Rescaling
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Data directories
data_dir = 'dataset'
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2

if not os.path.isdir(data_dir):
    raise FileNotFoundError(f"Training data directory not found: {data_dir!r}")

# Random augmentation is applied to the training subset ONLY. Validation
# images must be seen unaltered, otherwise val metrics are noisy and biased.
# Both generators use the same validation_split: Keras assigns files to the
# 'training'/'validation' subsets by their sorted position within each class
# directory, so the two subsets are disjoint even across generator instances.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
)
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
)

train_generator = train_datagen.flow_from_directory(
    data_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='training',
)

validation_generator = val_datagen.flow_from_directory(
    data_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='validation',
    shuffle=False,  # order is irrelevant for evaluation; keep it deterministic
)

if train_generator.samples == 0 or validation_generator.samples == 0:
    raise ValueError(
        f"No images found for training ({train_generator.samples}) or "
        f"validation ({validation_generator.samples}) in {data_dir!r}; "
        "expected one sub-directory of images per class."
    )

# Load pre-trained MobileNetV2 and freeze it (feature extractor only).
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=IMG_SIZE + (3,))
base_model.trainable = False

# Build the classifier. The generators yield pixels in [0, 1], but the
# ImageNet MobileNetV2 weights expect [-1, 1] (what
# mobilenet_v2.preprocess_input produces). Rescaling maps x -> 2*x - 1 inside
# the model, so the saved model keeps its [0, 1] input contract.
inputs = Input(shape=IMG_SIZE + (3,))
x = Rescaling(scale=2.0, offset=-1.0)(inputs)
# training=False keeps the frozen BatchNorm layers in inference mode.
x = base_model(x, training=False)
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x)
model = Model(inputs=inputs, outputs=predictions)

# Compile
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train
model.fit(train_generator, validation_data=validation_generator, epochs=10)

# Save model
model.save('strawberry_model.h5')
print("Model trained and saved as strawberry_model.h5")