# project/app.py
# Last revision: "Update app.py" (commit 1a8155e) by nikethanreddy.
from pathlib import Path

import gradio as gr
import numpy as np
from tensorflow.keras.models import load_model
from tkan import TKAN
from tkat import TKAT
from keras.utils import custom_object_scope
# Load the trained TKAN model once at startup.
# The custom layer classes are passed through ``custom_objects`` so Keras can
# deserialize them for this ``load_model`` call directly, instead of relying
# on a ``keras.utils`` scope being honoured by ``tensorflow.keras``.
# The file is resolved next to this script so the app works no matter which
# directory it is launched from.
MODEL_PATH = (
    Path(__file__).resolve().parent / "best_model_TKAN_nahead_1 (2).keras"
)
model = load_model(
    str(MODEL_PATH),
    custom_objects={"TKAN": TKAN, "TKAT": TKAT},
)
def predict(pm25, pm10, co, temp):
    """Run the loaded model on one set of inputs and return its first output.

    Args:
        pm25, pm10, co, temp: The four numeric features, fed to the model as
            a single sample in this order.

    Returns:
        The model's first output value for this sample, as a Python float.

    Raises:
        gr.Error: If any input is empty (None). Gradio shows this message to
            the user instead of an opaque tensor-conversion failure.
    """
    values = (pm25, pm10, co, temp)
    # An empty gr.Number field arrives as None; numpy would build an object
    # array from it and Keras would fail with an unhelpful error.
    if any(value is None for value in values):
        raise gr.Error("Please fill in all four input values.")

    # Shape (1, 4): a batch holding one sample of four features.
    # NOTE(review): if the TKAN/TKAT model expects 3-D input
    # (batch, timesteps, features), this array needs reshaping - confirm
    # against model.input_shape.
    input_data = np.array([values], dtype=np.float32)

    # verbose=0 stops Keras from printing a progress bar on every request.
    output = model.predict(input_data, verbose=0)
    return float(output[0][0])
# Gradio UI: four numeric inputs, passed to ``predict`` positionally in the
# order listed below, and one numeric output showing the model's result.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Number(label="PM2.5"),
        gr.Number(label="PM10"),
        gr.Number(label="CO"),
        gr.Number(label="Temperature"),
    ],
    outputs=gr.Number(label="Prediction"),
)

# Start the web server only when executed as a script (e.g. `python app.py`),
# so importing this module does not launch the UI.
if __name__ == "__main__":
    interface.launch()