# rice-scanner / task_3_model.py
# Last commit by NickNam2710: "update load model" (1413913)
import os
import tensorflow as tf
from huggingface_hub import hf_hub_download
@tf.keras.utils.register_keras_serializable(package="custom_models")
class InceptionModel(tf.keras.Model):
    """Single-output model built on an InceptionV3 backbone.

    The wrapped functional network (``self.net``) does the following:
    rescales 256x256x3 inputs by 1/255, runs them through a randomly
    initialised (``weights=None``) InceptionV3 base with global max pooling,
    then applies dropout, an L2-regularised ReLU dense layer and a single
    linear output unit. Presumably that unit is a logit for binary
    classification; confirm against the loss used at compile time.

    The class is registered under the ``custom_models`` package so that saved
    ``.keras`` archives referencing it can be deserialised by
    ``tf.keras.models.load_model``.

    NOTE(review): the layer construction order in ``__init__`` and the ``net``
    attribute name are deliberately left unchanged. Saved weights are
    presumably matched against them on load, so reordering could break
    existing checkpoints.
    """

    def __init__(
        self,
        dropout_rate: float,
        l2_reg: float,
        dense_units: int,
        *,
        name="InceptionV3Model",
        **kwargs,
    ):
        """Build the wrapped network.

        Args:
            dropout_rate: Dropout rate applied to the pooled backbone features.
            l2_reg: L2 factor for the hidden dense layer's kernel regulariser.
            dense_units: Number of units in the hidden ReLU dense layer.
            name: Keras name of this model.
            **kwargs: Forwarded to ``tf.keras.Model.__init__``. For example,
                ``from_config`` passes through the remaining base-config keys.
        """
        super().__init__(name=name, **kwargs)
        # Stored so get_config() can serialise the constructor arguments.
        self.dropout_rate = dropout_rate
        self.l2_reg = l2_reg
        self.dense_units = dense_units
        l2 = tf.keras.regularizers.L2(l2_reg)
        inception_base = tf.keras.applications.InceptionV3(
            include_top=False,
            input_shape=(256, 256, 3),
            pooling='max',
            weights=None
        )
        inputs = tf.keras.layers.Input(shape=(256, 256, 3))
        x = tf.keras.layers.Rescaling(1./255.)(inputs)
        x = inception_base(x)
        x = tf.keras.layers.Dropout(dropout_rate)(x)
        x = tf.keras.layers.Dense(dense_units, activation="relu", kernel_regularizer=l2)(x)
        outputs = tf.keras.layers.Dense(1)(x)
        self.net = tf.keras.Model(inputs, outputs)

    def call(self, inputs, training=False):
        """Forward pass: delegate to the wrapped functional network."""
        return self.net(inputs, training=training)

    def get_config(self):
        """Return the base Keras config extended with the constructor args."""
        config = super().get_config()
        config.update({
            "dropout_rate": self.dropout_rate,
            "l2_reg": self.l2_reg,
            "dense_units": self.dense_units,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from a dict produced by ``get_config``.

        The custom keys are removed from a shallow copy of ``config``, so the
        caller's dict is left untouched. The remaining keys are forwarded to
        the constructor as keyword arguments.
        """
        config = dict(config)
        dropout_rate = config.pop("dropout_rate")
        l2_reg = config.pop("l2_reg")
        dense_units = config.pop("dense_units")
        return cls(dropout_rate, l2_reg, dense_units, **config)

    @tf.function(
        input_signature=[(
            tf.TensorSpec([None, 256, 256, 3], tf.float32),
            tf.TensorSpec([None,], tf.float32),
        )],
        reduce_retracing=True,
    )
    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: An ``(x, y)`` tuple. The ``input_signature`` above fixes ``x``
                to a float32 tensor of shape ``(batch, 256, 256, 3)`` and ``y``
                to a float32 tensor of shape ``(batch,)``. Batches that do not
                match this structure (e.g. integer labels, or an extra
                sample-weight element) will fail in ``tf.function``.

        Returns:
            Dict mapping each tracked metric's name to its current result.
        """
        x, y = data
        # Labels arrive as a flat vector; reshape to (batch, 1) so they line
        # up with the single-unit output of ``self.net``.
        y = tf.reshape(y, (-1, 1))
        with tf.GradientTape() as tape:
            y_pred = self.net(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # The "loss" tracker consumes the scalar loss; every other metric is
        # updated from labels and predictions.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
def load_task_3_model(
    model_name='task_3_ensemble_model_og_data.keras',
    *,
    repo_id="NickNam2710/predict_rice_diseases",
    revision="main",
):
    """Download a saved Keras model from the Hugging Face Hub and load it.

    Args:
        model_name: File name of the ``.keras`` archive inside the repository.
        repo_id: Hub repository to download from. Defaults to the project's
            model repository.
        revision: Git revision (branch, tag or commit hash) to fetch. The
            default is ``"main"``, which tracks the latest upload. Pass a commit
            hash to pin a reproducible model version.

    Returns:
        The model object returned by ``tf.keras.models.load_model``.

    Errors raised by ``hf_hub_download`` (network failures, missing file) or
    by ``load_model`` propagate unchanged.
    """
    # hf_hub_download returns a local file path, reusing the Hub's local
    # cache when the file was already downloaded.
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=model_name,
        revision=revision,
    )
    # NOTE(review): if the archive contains InceptionModel instances, loading
    # relies on that class having been registered via its
    # register_keras_serializable decorator when this module was imported.
    model = tf.keras.models.load_model(model_path)
    return model