Spaces:
Build error
Build error
File size: 2,410 Bytes
991bb01 6d53984 c25bef4 991bb01 de1fe70 c25bef4 3dc57ed 991bb01 3dc57ed 991bb01 3dc57ed 991bb01 14fb364 991bb01 5919981 991bb01 5919981 ecf5b45 991bb01 8ba0f0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import io
import numpy as np
import pandas as pd
from PIL import Image as PILImage
import torch
from era_data import TabletPeriodDataset
from VAE_model_tablets_class import VAE
import gradio as gr
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
# map_location remaps any stored tensors onto `device`, so a file written on a
# GPU machine still loads on a CUDA-less host (and vice versa).
class_weights = torch.load("class_weights_period.pt", map_location=device)
checkpoint_path = 'epoch=22-step=213621.ckpt'
# Keyword arguments after map_location are forwarded to the VAE constructor.
vae_model = VAE.load_from_checkpoint(
    checkpoint_path,
    map_location=device,
    image_channels=1,
    z_dim=12,
    lr=0.0001,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
)
# Place the weights on the chosen device explicitly and switch to inference mode.
vae_model.to(device)
vae_model.eval()
# Load the saved encodings and average every remaining column per period.
# Each row of df_means is one period; presumably the non-Period_Name columns
# are the latent dimensions fed to the decoder (TODO confirm against the CSV).
df_encodings = pd.read_csv('df_vae_encoding_April16_all.csv')
df_means = df_encodings.drop(["Period", "Genre", "Genre_Name", "CDLI_id"], axis=1).groupby("Period_Name").mean().reset_index()
period_names = df_means['Period_Name'].unique()
def get_image_from_period(period_name, means=None):
    """Return the mean encoding of one period as a 1-D float32 tensor.

    Despite the name, no image is produced. The function finds the row of the
    per-period mean table whose ``Period_Name`` equals *period_name*. It
    returns that row's remaining columns, in table order, as a tensor.

    Args:
        period_name: Value to match in the ``Period_Name`` column.
        means: Table to search. Defaults to the module-level ``df_means``.

    Returns:
        A ``torch.Tensor`` of dtype float32 with one entry per
        non-``Period_Name`` column.

    Raises:
        ValueError: If no row has ``Period_Name == period_name``
            (for example an unselected dropdown passing ``None``).
    """
    table = df_means if means is None else means
    matches = table[table["Period_Name"] == period_name]
    # Fail with a readable message instead of an IndexError from .values[0].
    if matches.empty:
        raise ValueError(f"Unknown period: {period_name!r}")
    row = matches.drop(["Period_Name"], axis=1).values[0].astype('float32')
    return torch.from_numpy(row)
def generate_image(period1, period2, interpolation_value):
    """Decode a latent vector blended between the mean encodings of two periods.

    The latent is ``(1 - t) * mean(period1) + t * mean(period2)``, where
    ``t = interpolation_value``. It is passed through ``vae_model.fc3`` and
    then ``vae_model.decoder``. The first channel of the first decoded sample
    is returned as an 8-bit greyscale image.

    Args:
        period1: ``Period_Name`` of the first period (weight ``1 - t``).
        period2: ``Period_Name`` of the second period (weight ``t``).
        interpolation_value: Blend factor ``t``. 0 reproduces period1's mean
            and 1 reproduces period2's.

    Returns:
        A ``PIL.Image.Image``, or ``None`` when either period has not been
        selected yet. Gradio shows ``None`` as an empty output.
    """
    # An unselected Dropdown passes None. Looking that up fails, so show an
    # empty output until both periods are chosen.
    if not period1 or not period2:
        return None
    start = get_image_from_period(period1)
    end = get_image_from_period(period2)
    t = interpolation_value
    latent = (1 - t) * start + t * end
    # get_image_from_period builds tensors from NumPy (CPU). Move the latent
    # to wherever the model's weights live to avoid a device mismatch on GPU hosts.
    model_device = next(vae_model.parameters()).device
    # Inference only: keep fc3 inside no_grad too, so no autograd graph is built.
    with torch.no_grad():
        decoder_input = vae_model.fc3(latent.to(model_device)).unsqueeze(0)
        decoded = vae_model.decoder(decoder_input)
    pixels = decoded[0][0].cpu().numpy()
    # The *255 scaling assumes decoder output in [0, 1]. Clip first so stray
    # values cannot wrap around when cast to uint8.
    pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
    return PILImage.fromarray(pixels)
def update_image(dropdown1, dropdown2, slider):
    """Event handler that regenerates the interpolated tablet for the current controls.

    The image is returned rather than pushed anywhere. Gradio assigns an event
    handler's return value to the ``outputs`` component(s) of the listener
    that called it.

    Args:
        dropdown1: Selected first period name.
        dropdown2: Selected second period name.
        slider: Interpolation factor between the two periods.

    Returns:
        The value produced by ``generate_image`` for these inputs.
    """
    return generate_image(dropdown1, dropdown2, slider)
# Period pickers, interpolation slider and a result image. The controls are
# module-level names, so the Interface can reuse them as its inputs.
with gr.Blocks() as inputOutput:
    dropdown1 = gr.Dropdown(choices=period_names.tolist(), label="Period 1")
    dropdown2 = gr.Dropdown(choices=period_names.tolist(), label="Period 2")
    slider = gr.Slider(0, 1, step=0.1, label="Interpolation")
    # Receives the value returned by the slider's change handler.
    output_image = gr.Image(label="", height=250, width=250)
    # The listener signature is change(fn, inputs, outputs, ...). All three
    # controls go in one inputs list, and the returned image goes to
    # output_image. Passed positionally, dropdown2 would become the output
    # and slider the api_name.
    slider.change(update_image, inputs=[dropdown1, dropdown2, slider], outputs=output_image)
# Gradio Interface: the two period dropdowns and the slider feed
# generate_image, and the decoded tablet is shown in a 250x250 image panel.
result_view = gr.Image(label="", height=250, width=250)
iface = gr.Interface(
    fn=generate_image,
    inputs=[dropdown1, dropdown2, slider],
    outputs=result_view,
    allow_flagging="never",
)


def _serve():
    """Launch the Interface with a public share link."""
    iface.launch(share=True)


if __name__ == "__main__":
    _serve()
|