"""Gradio demo for room-style recognition and room-image generation.

Two models are loaded once at import time:

* a VGG16 classifier whose final layer is resized to 8 room styles and whose
  weights come from the ``sk2003/style_recognizer_vgg`` checkpoint, and
* the ``sk2003/room-styler`` Stable Diffusion pipeline.

The Gradio UI exposes one tab per model.
"""

from typing import Optional

import gradio as gr
import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Style labels, in the order of the classifier's output indices.
STYLE_CLASSES = (
    "Classic",
    "Modern",
    "Vintage",
    "Glamour",
    "Scandinavian",
    "Rustic",
    "ArtDeco",
    "Industrial",
)

# Preprocessing applied to every image before classification. Built once at
# import time instead of on every call to ``predict``.
PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Download the fine-tuned VGG16 model
vgg16_model_path = hf_hub_download(
    repo_id="sk2003/style_recognizer_vgg",
    filename="vgg16_model.pth",
)

# Build the VGG16 architecture without downloading ImageNet weights: the
# strict ``load_state_dict`` call below replaces every parameter anyway.
vgg16 = models.vgg16()
for param in vgg16.parameters():
    param.requires_grad = False  # Freeze parameters

# Update the last fully connected layer to match the number of classes
num_classes = len(STYLE_CLASSES)
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
vgg16 = vgg16.to(device)

# Load the fine-tuned model state.
# NOTE(security): torch.load unpickles the downloaded file, which can execute
# arbitrary code. Only load checkpoints from a trusted repository; consider
# ``weights_only=True`` if the installed torch version and the checkpoint
# contents allow it.
checkpoint = torch.load(vgg16_model_path, map_location=device)
vgg16.load_state_dict(checkpoint['model_state_dict'])
vgg16.eval()  # Set the model to evaluation mode

# Load the fine-tuned Stable Diffusion model.
# Half precision is only used on CUDA; on CPU fall back to float32.
model_id = "sk2003/room-styler"
sd_dtype = torch.float16 if device.type == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=sd_dtype)
pipe.to(device)


# Prediction function for the VGG16 model
def predict(image: Optional[Image.Image]) -> str:
    """Classify the style of a room photo.

    Args:
        image: PIL image supplied by the Gradio image component, or ``None``
            when the user has not provided one.

    Returns:
        The name of the predicted style, one of ``STYLE_CLASSES``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before predicting.")

    # The normalization constants are 3-channel, so force RGB input.
    image_tensor = PREPROCESS(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = vgg16(image_tensor)
    predicted = outputs.argmax(dim=1)
    return STYLE_CLASSES[predicted.item()]


# Generation function for the Stable Diffusion model
def generate_image(prompt: str):
    """Generate one image from a text prompt.

    Args:
        prompt: Text passed to the Stable Diffusion pipeline.

    Returns:
        The first image produced by the pipeline.
    """
    return pipe(prompt).images[0]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Room Style Recognition and Generation")  # Title

    # 1st tab
    with gr.Tab("Recognize Room Style"):
        image_input = gr.Image(type="pil")
        label_output = gr.Textbox()
        btn_predict = gr.Button("Predict Style")
        btn_predict.click(predict, inputs=image_input, outputs=label_output)

    # 2nd tab
    with gr.Tab("Generate Room Style"):
        text_input = gr.Textbox(placeholder="Enter a prompt for room style...")
        image_output = gr.Image()
        btn_generate = gr.Button("Generate Image")
        btn_generate.click(generate_image, inputs=text_input, outputs=image_output)

# Only start the server when run as a script, so importing this module (for
# example to reuse ``predict``) does not launch the UI.
if __name__ == "__main__":
    demo.launch()