# strawberryPicker/scripts/train_model.py
# Gareth
# Initial clean commit for Hugging Face
# efb1801
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
import os
# Data directories
# flow_from_directory expects one sub-directory per class inside this folder.
data_dir = 'dataset'

# NOTE(review): Keras documents mobilenet_v2.preprocess_input as scaling pixels
# to [-1, 1], whereas this script feeds [0, 1] via rescale=1./255. Left
# unchanged here because inference code presumably mirrors the 1/255 scaling --
# confirm before switching.

# Training images: rescaled to [0, 1] and randomly augmented
# (rotation, horizontal/vertical shifts, horizontal flips).
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True
)

# Validation images: rescaled only. Evaluating on randomly augmented images
# makes validation metrics noisy and not comparable between epochs, so the
# validation subset gets its own generator without augmentation. It uses the
# same validation_split so both generators partition the files identically
# (the subset is derived from the file listing and the split fraction, not
# from the augmentation settings).
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2
)

# 80% of the images, yielded as batches of 32 resized to 224x224 with
# binary (0/1) labels.
train_generator = train_datagen.flow_from_directory(
    data_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary',
    subset='training'
)

# Remaining 20% of the images, same batch shape, no augmentation.
validation_generator = val_datagen.flow_from_directory(
    data_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary',
    subset='validation'
)
# Backbone: MobileNetV2 with ImageNet weights and without its original
# classifier, taking 224x224 RGB inputs.
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3),
)

# Freeze every backbone layer so that only the newly added head is updated
# during training.
for backbone_layer in base_model.layers:
    backbone_layer.trainable = False

# Classification head: global average pooling -> 1024-unit ReLU layer ->
# single sigmoid unit producing the binary prediction.
pooled = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(1024, activation='relu')(pooled)
predictions = Dense(1, activation='sigmoid')(hidden)

# Full network from the backbone input to the sigmoid output.
model = Model(inputs=base_model.input, outputs=predictions)
# Configure training: Adam optimizer, binary cross-entropy loss (matches the
# single sigmoid output), and accuracy as the reported metric.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Run 10 epochs over the training generator, evaluating on the validation
# generator after each epoch.
model.fit(
    x=train_generator,
    validation_data=validation_generator,
    epochs=10,
)

# Persist the trained model to an HDF5 file in the current working directory.
model.save('strawberry_model.h5')
print("Model trained and saved as strawberry_model.h5")