# Virgual_Reality / app.py
# Last commit 312451f by BrandDead: "Update paths to models directory"
from pathlib import Path

import gradio as gr
import joblib
import numpy as np
from tensorflow.keras.models import load_model
# Load every model once at import time. Paths are resolved relative to this
# file rather than the current working directory, so the app starts correctly
# regardless of which directory it is launched from.
MODELS_DIR = Path(__file__).resolve().parent / "models"

# compile=False skips restoring the training configuration; the handlers in
# this file only ever call predict() on these models.
rnn_model = load_model(str(MODELS_DIR / "virgil_rnn_model.keras"), compile=False)
gan_generator = load_model(
    str(MODELS_DIR / "virgil_gan_generator.keras"), compile=False
)
# NOTE(review): vae_model is not referenced by any handler in this file; keep it
# only if something else imports it, since loading it costs startup time/memory.
vae_model = load_model(
    str(MODELS_DIR / "virgil_autoencoder_model.keras"), compile=False
)
# joblib.load unpickles the file, which can execute arbitrary code: only ever
# point this at a trusted, locally shipped artifact.
rf_model = joblib.load(str(MODELS_DIR / "virgil_rf_finetuned_model.pkl"))
# Define functions for each model
def learn_fashion(input_data):
    """Run the fine-tuned random-forest model on a single input.

    ``input_data`` is wrapped in a one-element outer array (a batch of one
    sample) and passed to ``rf_model.predict``. The prediction for that single
    sample is returned.

    Unlike the other handlers, no ``reshape(1, -1)`` is applied here.
    TODO confirm which input layout the pickled model expects; a 1-D batch
    suits a text pipeline, while a list/array input already becomes 2-D.
    """
    batch = np.array([input_data])
    return rf_model.predict(batch)[0]
def respond_like_virgil(input_data):
    """Run the RNN model on a single input and return its output row.

    ``input_data`` is flattened into one row, giving a batch of shape
    ``(1, n)``, before it is passed to ``rnn_model.predict``. The first (only)
    row of the prediction is returned.
    """
    # NOTE(review): recurrent layers usually expect (batch, timesteps, features);
    # confirm the saved model's input signature accepts this 2-D batch.
    row = np.reshape(np.array([input_data]), (1, -1))
    outputs = rnn_model.predict(row)
    return outputs[0]
def design_with_gan(input_data):
    """Generate a single output with the GAN generator.

    ``input_data`` is flattened into a single ``(1, n)`` row and fed to
    ``gan_generator.predict``. The first (only) generated sample is returned.
    """
    # Presumably input_data is a latent vector whose length matches the
    # generator's input size; TODO confirm against the saved model.
    latent_batch = np.array([input_data]).reshape((1, -1))
    return gan_generator.predict(latent_batch)[0]
# Route a UI request to the model handler selected in the dropdown.
def choose_action(action, input_data):
    """Dispatch ``input_data`` to the handler named by ``action``.

    Args:
        action: One of "Learn Fashion and Branding", "Respond Like Virgil"
            or "Design with GAN" (the dropdown choice).
        input_data: Passed through unchanged to the selected handler. When
            called from the Gradio UI, this is the raw textbox value.

    Returns:
        Whatever the selected handler returns.

    Raises:
        ValueError: If ``action`` is not one of the supported choices.
            Raising, rather than silently returning None, makes a missing or
            mistyped action visible instead of producing a blank output.
    """
    if action == "Learn Fashion and Branding":
        return learn_fashion(input_data)
    if action == "Respond Like Virgil":
        return respond_like_virgil(input_data)
    if action == "Design with GAN":
        return design_with_gan(input_data)
    raise ValueError(
        f"Unknown action {action!r}; expected one of: "
        "'Learn Fashion and Branding', 'Respond Like Virgil', 'Design with GAN'."
    )
# Build the Gradio UI. A bare "dropdown" shortcut creates a Dropdown with no
# choices, so the actions are listed explicitly; they must match the action
# names handled by choose_action exactly.
interface = gr.Interface(
    fn=choose_action,
    inputs=[
        gr.Dropdown(
            choices=[
                "Learn Fashion and Branding",
                "Respond Like Virgil",
                "Design with GAN",
            ],
            value="Learn Fashion and Branding",  # avoid an empty initial selection
            label="Action",
        ),
        gr.Textbox(label="Input"),  # user-supplied data for the chosen model
    ],
    outputs=gr.Textbox(label="Prediction"),  # the model's prediction as text
    # live=True re-runs the model on every input change (no Submit button).
    live=True,
)

# Launch only when run as a script, so importing this module does not start
# a web server.
if __name__ == "__main__":
    interface.launch()