File size: 1,927 Bytes
1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# NOTE: "!pip install ..." is IPython/Jupyter shell-escape syntax and is a
# SyntaxError in a plain Python script, so the installs are recorded here as
# instructions instead. Install the dependencies before running this file
# (for a hosted app, list them in requirements.txt):
#
#   pip install -U adapter-transformers
#   pip install -U transformers
#
# NOTE(review): adapter-transformers appears to ship its own `transformers`
# package; installing both may conflict. This script only imports CLIPModel and
# CLIPProcessor from `transformers` -- confirm whether both are really needed.
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
# The processor and the model weights are pulled from the same checkpoint,
# so the identifier is spelled out once and reused for both loads.
MODEL_ID = "Taarhoinc/TaarhoGen1"
processor = CLIPProcessor.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID)
# Define the function to describe a floor plan
def describe_floorplan(floorplan_image: Image.Image, top_k: int = 3):
    """Describe a floor plan drawing by listing its best-matching components.

    The image is scored against a fixed list of component names using the
    module-level ``processor`` and ``model``; the ``top_k`` highest-scoring
    names are returned, best match first.

    Args:
        floorplan_image: The floor plan as a PIL image. ``None`` (nothing
            uploaded in the UI) is rejected with a user-facing error.
        top_k: How many component names to return. The value is coerced to
            ``int`` (UI sliders may deliver a float) and clamped to the range
            ``0..len(components)``.

    Returns:
        The selected component names joined by ``", "``. An empty string is
        returned when ``top_k`` is zero or negative.

    Raises:
        gr.Error: If ``floorplan_image`` is ``None``.
    """
    if floorplan_image is None:
        # Fail with a readable message instead of an opaque processor error.
        raise gr.Error("Please upload a floor plan image first.")

    # Candidate labels the image is compared against.
    components = [
        "bedroom",
        "kitchen",
        "bathroom",
        "living room",
        "dining room",
        "hallway",
        "garage",
        "balcony",
        "stairs",
        "door",
        "window",
    ]

    # Normalise top_k before it is used for selection: a float would break
    # indexing, and 0 / negative values previously produced a slice such as
    # ``[-0:]`` that returned every component instead of none.
    top_k = max(0, min(int(top_k), len(components)))
    if top_k == 0:
        return ""

    # Preprocess the image and text prompts
    inputs = processor(
        text=components, images=floorplan_image, return_tensors="pt", padding=True
    )

    # Inference only: no gradients are needed for the similarity scores.
    with torch.no_grad():
        outputs = model(**inputs)

    # One row per image; softmax turns the image-text logits into a
    # distribution over the candidate components.
    probs = outputs.logits_per_image.softmax(dim=1)[0]

    # Indices of the top_k most probable components, highest first.
    top_k_indices = torch.topk(probs, k=top_k).indices.tolist()

    return ", ".join(components[i] for i in top_k_indices)
# Create the Gradio interface: an image upload plus a slider selecting how many
# components to list; the function's string result is shown in a label widget.
# The interface is bound to `demo` (the conventional Gradio name) so it can be
# referenced after construction, then launched.
demo = gr.Interface(
    fn=describe_floorplan,
    inputs=[
        gr.Image(label="Upload a floor plan drawing", type="pil"),
        gr.Slider(1, 10, step=1, value=3, label="Number of components to detect"),
    ],
    outputs=gr.Label(label="Detected Components"),
    title="Floor Plan Description with TaarhoGen1",
    description="Upload a floor plan drawing to get a list of detected components.",
)
demo.launch()