Tfttest / app.py
Badumetsibb's picture
Update app.py
b624c59 verified
raw
history blame contribute delete
944 Bytes
import gradio as gr
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from keras.saving import register_keras_serializable
@register_keras_serializable()
class SimplifiedTFT_Iter3(tf.keras.Model):
    """Custom Keras model class registered for (de)serialization.

    The decorator adds this class to Keras' serializable-object registry so a
    saved model that references it by name can be resolved at load time.

    NOTE(review): in this file the class body is only ``...`` -- no
    ``__init__``, ``call()`` or ``get_config()`` is defined. If the saved
    ``tft_model.keras`` was built from this class, Keras will rebuild it from
    this stub, which likely cannot reproduce the trained architecture.
    Presumably the full implementation used during training (same class name)
    belongs here -- TODO confirm.
    """
    ...
# Trained TFT network, restored once when the app starts and shared by every
# request. The path is resolved relative to the current working directory.
# compile=False: the model is not compiled after loading, and predict() does
# not require a compiled model.
MODEL_PATH = "tft_model.keras"
model = tf.keras.models.load_model(
    MODEL_PATH,
    compile=False,
)
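# Scalers: load the ones fitted during training if they were saved separately,
# or re-create them here with the same fitted parameters.
# NOTE: StandardScaler is imported but not used yet, so predict_from_csv
# currently feeds unscaled CSV values to the model.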
def predict_from_csv(file, *, window: int = 1):
    """Run the loaded TFT model on an uploaded CSV and return the forecast as text.

    The CSV is read with pandas. Its last ``window`` rows are fed to the model
    as a single sample of shape ``(1, window, n_columns)``.

    NOTE(review): no preprocessing is applied here. Column selection/order and
    any scaling must match what was done before ``model.fit()`` during
    training. Until the fitted scaler(s) are loaded, raw CSV values are passed
    straight to the model.

    Args:
        file: The Gradio upload. Depending on the Gradio version this is either
            a filepath string (Gradio 4.x) or a tempfile-like object exposing
            ``.name`` (Gradio 3.x); both are accepted. ``None`` when the user
            submits without uploading anything.
        window: Number of trailing rows used as the model's time window.
            Defaults to 1, which is the original behaviour. Presumably this
            should equal the sequence length the model was trained on -- TODO
            confirm.

    Returns:
        ``"Prediction: <value>"``, where ``<value>`` is the first element of
        the flattened model output.

    Raises:
        gr.Error: If no file was uploaded, the CSV has fewer than ``window``
            data rows, or its values cannot be converted to floats. Gradio
            shows the message to the user.
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    if file is None:
        raise gr.Error("Please upload a CSV file.")

    # Gradio 3.x hands over a tempfile wrapper (path in ``.name``); Gradio 4.x
    # hands over the path string itself. Accept either.
    path = getattr(file, "name", file)
    df = pd.read_csv(path)

    if len(df) < window:
        # Without this check an empty window reaches the model as a
        # zero-length sequence and fails with an opaque TensorFlow error.
        raise gr.Error(
            f"The CSV needs at least {window} data row(s); got {len(df)}."
        )

    try:
        # Keras computes in float32 anyway. Converting here also turns
        # non-numeric columns into a readable error for the user.
        window_values = df.tail(window).to_numpy(dtype=np.float32)
    except (ValueError, TypeError) as err:
        raise gr.Error(f"All CSV columns must be numeric: {err}") from err

    # Add the batch axis: (window, n_features) -> (1, window, n_features).
    input_data = np.expand_dims(window_values, axis=0)

    # verbose=0 keeps Keras' per-call progress bar out of the server logs.
    pred = model.predict(input_data, verbose=0)
    return f"Prediction: {pred.flatten()[0]}"
# Web UI: a single uploaded file goes in, and the prediction comes back as
# plain text. It is bound to ``demo``, the name Gradio's reload mode looks for.
demo = gr.Interface(
    fn=predict_from_csv,
    inputs="file",
    outputs="text",
)
demo.launch()