Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,13 +3,13 @@ import gradio as gr
|
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
from PIL import Image
|
| 6 |
-
from transformers import AutoModel, AutoProcessor
|
| 7 |
from torch import nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
from datasets import load_dataset
|
| 10 |
from torch.utils.data import Dataset, DataLoader
|
| 11 |
import os
|
| 12 |
from tqdm import tqdm
|
|
|
|
| 13 |
|
| 14 |
class SDDataset(Dataset):
|
| 15 |
def __init__(self, dataset, processor, model_to_idx, token_to_idx, max_samples=5000):
|
|
@@ -49,7 +49,7 @@ class SDRecommenderModel(nn.Module):
|
|
| 49 |
|
| 50 |
def forward(self, image_features):
|
| 51 |
# Get Florence embeddings
|
| 52 |
-
features = self.florence.get_image_features(image_features)
|
| 53 |
|
| 54 |
# Generate model and prompt recommendations
|
| 55 |
model_logits = self.model_head(features)
|
|
@@ -62,16 +62,16 @@ class SDRecommender:
|
|
| 62 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 63 |
print(f"Using device: {self.device}")
|
| 64 |
|
| 65 |
-
# Load Florence model and processor
|
| 66 |
print("Loading Florence model and processor...")
|
| 67 |
-
self.processor = AutoProcessor.from_pretrained(
|
| 68 |
"microsoft/Florence-2-large",
|
| 69 |
trust_remote_code=True
|
| 70 |
)
|
| 71 |
-
self.florence = AutoModel.from_pretrained(
|
| 72 |
"microsoft/Florence-2-large",
|
| 73 |
trust_remote_code=True
|
| 74 |
-
)
|
| 75 |
|
| 76 |
# Load dataset
|
| 77 |
print("Loading dataset...")
|
|
@@ -139,7 +139,7 @@ class SDRecommender:
|
|
| 139 |
|
| 140 |
for batch_idx, (images, model_labels, prompt_labels) in enumerate(progress_bar):
|
| 141 |
# Move everything to device
|
| 142 |
-
images =
|
| 143 |
model_labels = model_labels.to(self.device)
|
| 144 |
prompt_labels = prompt_labels.to(self.device)
|
| 145 |
|
|
@@ -170,7 +170,8 @@ class SDRecommender:
|
|
| 170 |
image = Image.open(image)
|
| 171 |
|
| 172 |
# Process image
|
| 173 |
-
inputs = self.processor(images=image, return_tensors="pt")
|
|
|
|
| 174 |
|
| 175 |
# Get model predictions
|
| 176 |
self.model.eval()
|
|
|
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
from PIL import Image
|
|
|
|
| 6 |
from torch import nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from datasets import load_dataset
|
| 9 |
from torch.utils.data import Dataset, DataLoader
|
| 10 |
import os
|
| 11 |
from tqdm import tqdm
|
| 12 |
+
from transformers import Florence2Model, Florence2Processor
|
| 13 |
|
| 14 |
class SDDataset(Dataset):
|
| 15 |
def __init__(self, dataset, processor, model_to_idx, token_to_idx, max_samples=5000):
|
|
|
|
| 49 |
|
| 50 |
def forward(self, image_features):
|
| 51 |
# Get Florence embeddings
|
| 52 |
+
features = self.florence.get_image_features(**image_features)
|
| 53 |
|
| 54 |
# Generate model and prompt recommendations
|
| 55 |
model_logits = self.model_head(features)
|
|
|
|
| 62 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 63 |
print(f"Using device: {self.device}")
|
| 64 |
|
| 65 |
+
# Load Florence model and processor
|
| 66 |
print("Loading Florence model and processor...")
|
| 67 |
+
self.processor = Florence2Processor.from_pretrained(
|
| 68 |
"microsoft/Florence-2-large",
|
| 69 |
trust_remote_code=True
|
| 70 |
)
|
| 71 |
+
self.florence = Florence2Model.from_pretrained(
|
| 72 |
"microsoft/Florence-2-large",
|
| 73 |
trust_remote_code=True
|
| 74 |
+
).to(self.device)
|
| 75 |
|
| 76 |
# Load dataset
|
| 77 |
print("Loading dataset...")
|
|
|
|
| 139 |
|
| 140 |
for batch_idx, (images, model_labels, prompt_labels) in enumerate(progress_bar):
|
| 141 |
# Move everything to device
|
| 142 |
+
images = {k: v.to(self.device) for k, v in images.items()}
|
| 143 |
model_labels = model_labels.to(self.device)
|
| 144 |
prompt_labels = prompt_labels.to(self.device)
|
| 145 |
|
|
|
|
| 170 |
image = Image.open(image)
|
| 171 |
|
| 172 |
# Process image
|
| 173 |
+
inputs = self.processor(images=image, return_tensors="pt")
|
| 174 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 175 |
|
| 176 |
# Get model predictions
|
| 177 |
self.model.eval()
|