Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| import torch.nn as nn | |
# ---------------------------------------------------------------------------
# One-time model setup (runs at import).
# ---------------------------------------------------------------------------

# Select GPU when available; all tensors/models below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the fine-tuned VGG16 classifier checkpoint from the Hub.
vgg16_model_path = hf_hub_download(
    repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth"
)

# Build VGG16 with ImageNet weights, freeze the backbone (inference only),
# and replace the final classifier layer to match the 8 room-style classes.
vgg16 = models.vgg16(pretrained=True)
for param in vgg16.parameters():
    param.requires_grad = False  # freeze parameters
num_classes = 8
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
vgg16 = vgg16.to(device)

# Restore the fine-tuned weights and switch to evaluation mode.
checkpoint = torch.load(vgg16_model_path, map_location=device)
vgg16.load_state_dict(checkpoint['model_state_dict'])
vgg16.eval()

# Load the fine-tuned Stable Diffusion pipeline.
# BUG FIX: half precision (float16) is only supported on CUDA; loading the
# pipeline with torch_dtype=torch.float16 and then moving it to CPU raises a
# runtime error at inference ("... not implemented for 'Half'"). Use fp16 on
# GPU and fall back to full precision on CPU-only hosts.
model_id = "sk2003/room-styler"
pipe_dtype = torch.float16 if device.type == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=pipe_dtype)
pipe.to(device)
# Classification entry point for the first Gradio tab.
def predict(image):
    """Classify a room photo into one of eight interior-design styles.

    Args:
        image: PIL image of a room interior.

    Returns:
        The predicted style name as a string.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet normalization statistics, matching VGG16's pretraining.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = vgg16(batch)

    # Class names in training order; index = argmax of the classifier logits.
    class_names = ["Classic", "Modern", "Vintage", "Glamour",
                   "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
    best_index = logits.argmax(dim=1).item()
    return class_names[best_index]
# Generation entry point for the second Gradio tab.
def generate_image(prompt):
    """Render a room image from a text prompt via the fine-tuned SD pipeline.

    Args:
        prompt: free-text description of the desired room style.

    Returns:
        The first generated PIL image.
    """
    result = pipe(prompt)
    return result.images[0]
# ---------------------------------------------------------------------------
# Gradio UI: two tabs, one per model.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Room Style Recognition and Generation")  # app title

    # Tab 1: upload a photo, classify its style with the VGG16 model.
    with gr.Tab("Recognize Room Style"):
        photo_in = gr.Image(type="pil")
        style_out = gr.Textbox()
        gr.Button("Predict Style").click(
            predict, inputs=photo_in, outputs=style_out
        )

    # Tab 2: type a prompt, generate a room with the Stable Diffusion model.
    with gr.Tab("Generate Room Style"):
        prompt_in = gr.Textbox(placeholder="Enter a prompt for room style...")
        render_out = gr.Image()
        gr.Button("Generate Image").click(
            generate_image, inputs=prompt_in, outputs=render_out
        )

demo.launch()