sk2003 committed on
Commit
73142ed
·
verified ·
1 Parent(s): 10cb574

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -6,31 +6,31 @@ from PIL import Image
6
  from huggingface_hub import hf_hub_download
7
  import torch.nn as nn
8
 
9
- # Loading the ResNet50 model from your Hugging Face repository
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- resnet50_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_resnet", filename="resnet50_model.pth")
12
 
13
- # ResNet50 model
14
- resnet50 = models.resnet50(pretrained=True)
15
- for param in resnet50.parameters():
16
- param.requires_grad = False
 
 
17
 
18
  num_classes = 8
19
- resnet50.fc = nn.Linear(resnet50.fc.in_features, num_classes)
20
- resnet50 = resnet50.to(device)
21
 
22
- # Loading the saved state dict
23
- checkpoint = torch.load(resnet50_model_path, map_location=device)
24
- resnet50.load_state_dict(checkpoint['model_state_dict'])
25
- resnet50.eval()
26
 
27
- # Fine-tuned Stable Diffusion model from your Hugging Face repository
28
  model_id = "sk2003/room-styler"
29
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
30
  pipe.to(device)
31
 
32
  # Prediction function for the ResNet50 model
33
- def predict_and_show(image):
34
  transform = transforms.Compose([
35
  transforms.Resize((224, 224)),
36
  transforms.ToTensor(),
@@ -42,10 +42,10 @@ def predict_and_show(image):
42
  outputs = resnet50(image_tensor)
43
  _, predicted = torch.max(outputs.data, 1)
44
 
45
- class_names = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
46
- predicted_label = class_names[predicted.item()]
47
 
48
- return predicted_label
49
 
50
  # Generation function for the Stable Diffusion model
51
  def generate_image(prompt):
@@ -54,14 +54,16 @@ def generate_image(prompt):
54
 
55
  # Gradio interface
56
  with gr.Blocks() as demo:
57
- gr.Markdown("## Room Style Recognition and Generation")
58
 
 
59
  with gr.Tab("Recognize Room Style"):
60
  image_input = gr.Image(type="pil")
61
  label_output = gr.Textbox()
62
  btn_predict = gr.Button("Predict Style")
63
- btn_predict.click(predict_and_show, inputs=image_input, outputs=label_output)
64
 
 
65
  with gr.Tab("Generate Room Style"):
66
  text_input = gr.Textbox(placeholder="Enter a prompt for room style...")
67
  image_output = gr.Image()
@@ -69,3 +71,4 @@ with gr.Blocks() as demo:
69
  btn_generate.click(generate_image, inputs=text_input, outputs=image_output)
70
 
71
  demo.launch()
 
 
6
  from huggingface_hub import hf_hub_download
7
  import torch.nn as nn
8
 
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
10
 
11
+ # Fine-tuned VGG16 model is downloaded
12
+ vgg16_model_path = hf_hub_download(repo_id="sk2003/style_recognizer_vgg", filename="vgg16_model.pth")
13
+
14
+ vgg16 = models.vgg16(pretrained=True)
15
+ for param in vgg16.parameters():
16
+ param.requires_grad = False # freezing parameters
17
 
18
  num_classes = 8
19
+ vgg16.fc = nn.Linear(vgg16.fc.in_features, num_classes)
20
+ vgg16 = vgg16.to(device)
21
 
22
+ # Loading the model
23
+ checkpoint = torch.load(vgg16_model_path, map_location=device)
24
+ vgg16.load_state_dict(checkpoint['model_state_dict'])
25
+ vgg16.eval() # setting to evaluation mode so dropout is disabled and any batch-norm layers use their running statistics
26
 
27
+ # Fine-tuned Stable Diffusion model
28
  model_id = "sk2003/room-styler"
29
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
30
  pipe.to(device)
31
 
32
  # Prediction function for the ResNet50 model
33
+ def predict(image):
34
  transform = transforms.Compose([
35
  transforms.Resize((224, 224)),
36
  transforms.ToTensor(),
 
42
  outputs = resnet50(image_tensor)
43
  _, predicted = torch.max(outputs.data, 1)
44
 
45
+ classes = ["Classic", "Modern", "Vintage", "Glamour", "Scandinavian", "Rustic", "ArtDeco", "Industrial"]
46
+ pred = classes[predicted.item()]
47
 
48
+ return pred
49
 
50
  # Generation function for the Stable Diffusion model
51
  def generate_image(prompt):
 
54
 
55
  # Gradio interface
56
  with gr.Blocks() as demo:
57
+ gr.Markdown("## Room Style Recognition and Generation") # title
58
 
59
+ # 1st tab
60
  with gr.Tab("Recognize Room Style"):
61
  image_input = gr.Image(type="pil")
62
  label_output = gr.Textbox()
63
  btn_predict = gr.Button("Predict Style")
64
+ btn_predict.click(predict, inputs=image_input, outputs=label_output)
65
 
66
+ # 2nd tab
67
  with gr.Tab("Generate Room Style"):
68
  text_input = gr.Textbox(placeholder="Enter a prompt for room style...")
69
  image_output = gr.Image()
 
71
  btn_generate.click(generate_image, inputs=text_input, outputs=image_output)
72
 
73
  demo.launch()
74
+