Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| import prediction | |
| import model | |
| import diffusion_loss | |
# Select the torch device string: "cuda" when torch can see a GPU, else "cpu".
# NOTE(review): `device` is not referenced anywhere else in the visible code —
# confirm whether it is still needed or should be handed to the model.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the diffusion pipeline once, at import time; every call to
# `generate` reuses this single instance.
pipe = model.initialize_diffusion_model()
def generate(prompt, loss_function=None):
    """Run the shared diffusion pipeline on *prompt*.

    Thin wrapper around ``prediction.predict`` that always supplies the
    module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to ``prediction.predict``.
        loss_function: Optional callable forwarded unchanged as
            ``loss_function``. ``process_input`` passes ``None`` for the
            "No Loss" choice.

    Returns:
        Whatever ``prediction.predict`` returns. The Interface displays it in
        a ``gr.Image(type="pil")``, so it is presumably an image — confirm.
    """
    predict_kwargs = dict(prompt=prompt, pipe=pipe, loss_function=loss_function)
    return prediction.predict(**predict_kwargs)
def process_input(prompt, loss_function, button):
    """Gradio callback: generate an image with the selected augmentation loss.

    Args:
        prompt: Text prompt from the Textbox.
        loss_function: Dropdown label, e.g. "No Loss", "Blue Channel",
            "Saturation" or "Elastic Deformation". ``None`` (nothing
            selected) and any unrecognised label are treated as "No Loss".
        button: Value delivered by the ``gr.Button`` input. A falsy value
            short-circuits to ``None``.
            NOTE(review): when a ``gr.Button`` is used as an Interface input,
            Gradio presumably passes the button's label ("Loss Function"),
            which is always truthy — confirm whether this guard can trigger.

    Returns:
        The result of ``generate`` for the chosen loss, or ``None`` when
        ``button`` is falsy.
    """
    if not button:
        return None

    # Map dropdown labels to loss callables. Labels missing from this table
    # ("No Loss", None, anything unknown) fall back to no loss. This matches
    # the three separate branches the chain previously had for that case.
    loss_functions = {
        "Blue Channel": diffusion_loss.blue_channel,
        "Saturation": diffusion_loss.saturation,
        "Elastic Deformation": diffusion_loss.elastic_transform,
    }
    return generate(prompt, loss_function=loss_functions.get(loss_function))
# Gradio UI wiring the prompt box, the augmentation dropdown and the button
# into `process_input`; the result is rendered as a PIL image.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        # Use `placeholder` rather than the first positional argument.
        # That argument is the Textbox's `value`, which would pre-fill the box
        # with the literal word "prompt" and submit it unless the user
        # erased it.
        gr.Textbox(placeholder="prompt", label="Enter Prompt"),
        # Labels must match the keys recognised by `process_input`; with no
        # default selected, the callback receives None and applies no loss.
        gr.Dropdown(
            ["No Loss", "Blue Channel", "Saturation", "Elastic Deformation"],
            label="Choose Augmentation",
        ),
        # Its value is fed to `process_input` as the `button` argument.
        gr.Button("Loss Function"),
    ],
    outputs=gr.Image(type="pil"),
)
if __name__ == "__main__":
    # Start the Gradio server with the auto-generated API page hidden and a
    # public share link requested.
    # NOTE(review): share=True is presumably ignored when running on Hugging
    # Face Spaces (Gradio warns about it there) — confirm the deployment target.
    iface.launch(share=True, show_api=False)