|
|
import streamlit as st |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import io |
|
|
import torch |
|
|
import timm |
|
|
import json |
|
|
from torchvision.transforms.v2 import ( |
|
|
ToImage, |
|
|
Compose, |
|
|
ToDtype, |
|
|
Normalize, |
|
|
) |
|
|
import pandas as pd |
|
|
import requests |
|
|
|
|
|
|
|
|
# Use the wide page layout for the two-column UI built further down.
st.set_page_config(layout="wide")

# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
def NativeResize(patch_size, seq_len_range):
    """Build a resizer that snaps an image onto a patch-aligned size.

    Candidate sizes are every ``(i * patch_size, j * patch_size)`` with
    ``4 <= i, j < 100`` whose aspect ratio ``i / j`` lies in ``[0.33, 3]``
    and whose patch count ``i * j`` lies in ``[lo, hi)``, where
    ``lo, hi = seq_len_range``.

    Args:
        patch_size: Multiplier applied to both grid dimensions to get the
            target width and height in pixels.
        seq_len_range: ``(lo, hi)`` pair bounding ``i * j``; ``lo`` is
            inclusive and ``hi`` is exclusive.

    Returns:
        A function that takes a PIL image and returns it resized (bicubic)
        to the candidate whose aspect ratio is closest to the image's own
        ``width / height``.

    Raises:
        ValueError: If no candidate size satisfies the constraints.
    """
    lo, hi = seq_len_range

    # Kept sorted so that ties in get_ratio() resolve to the same candidate
    # regardless of how the grid is enumerated.
    refs = sorted(
        (cols / rows, cols * patch_size, rows * patch_size)
        for cols in range(4, 100)
        for rows in range(4, 100)
        if 0.33 <= cols / rows <= 3 and lo <= cols * rows < hi
    )

    # Fail at construction with a clear message; otherwise the first resize
    # would die inside min() with "arg is an empty sequence".
    if not refs:
        raise ValueError(
            f"no patch grid with 4..99 patches per side has a patch count in [{lo}, {hi})"
        )

    def get_ratio(r):
        # Distance is the larger ratio over the smaller one, minus 1, so
        # "twice as wide" and "half as wide" are equally far away.
        return min(refs, key=lambda ref: max(r, ref[0]) / min(r, ref[0]) - 1)

    def resize(im: Image.Image):
        w, h = im.size
        _, target_w, target_h = get_ratio(w / h)
        return im.resize((target_w, target_h), resample=Image.Resampling.BICUBIC)

    return resize
|
|
|
|
|
def load_json_from_url(url, timeout=10):
    """Fetch *url* with an HTTP GET and return its body parsed as JSON.

    Args:
        url: Address to request.
        timeout: Seconds to wait on the server before giving up. Without a
            timeout, requests waits indefinitely on a stalled connection.

    Returns:
        The parsed JSON value, or ``None`` when the request fails (network
        error, timeout, non-success status) or the body is not valid JSON.
        The failure is printed before ``None`` is returned.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return json.loads(response.text)
    except requests.exceptions.RequestException as e:
        # Covers connection errors, timeouts and the HTTPError raised by
        # raise_for_status().
        print(f"Error fetching data from URL: {e}")
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON data: {e}")
    return None
|
|
|
|
|
|
|
|
@st.cache_data
def load_tags():
    """Download the tag list and append a trailing placeholder entry.

    Returns:
        The entries of ``freqs.json`` as a list, followed by
        ``("PLACEHOLDER", 0)``. Callers index each entry as
        ``entry[0]`` (label) and ``entry[1]`` (dataset frequency).

    Raises:
        RuntimeError: If the tag file could not be fetched or decoded.
    """
    freqs = load_json_from_url(
        "https://huggingface.co/gustproof/dnbr-tagger-preview1/raw/main/freqs.json"
    )
    # load_json_from_url signals failure by returning None; unpacking None
    # below would raise an unhelpful TypeError, so stop with a clear message.
    if freqs is None:
        raise RuntimeError(
            "Could not load tag frequencies (freqs.json) from the Hugging Face Hub."
        )
    # NOTE(review): the placeholder presumably pads the list to the model's
    # output size - confirm against the model.
    return [*freqs, ("PLACEHOLDER", 0)]


tags = load_tags()
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Create the tagger network and return it wrapped with preprocessing.

    Returns:
        A ``Model`` instance exposing ``class_names`` (the result of
        ``load_tags()``) and ``predict(img)``.
    """
    # Gradients are never needed here: the network is only used by predict().
    torch.set_grad_enabled(False)

    network = timm.create_model(
        "hf_hub:gustproof/dnbr-tagger-preview1",
        pretrained=True,
        dynamic_img_size=True,
    )
    network = network.eval().to(device)
    print("loaded model")

    # Image -> scaled float tensor -> per-channel normalization.
    preprocess = Compose(
        [
            ToImage(),
            ToDtype(torch.float, scale=True),
            Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
        ]
    )

    class Model:
        def __init__(self):
            self.class_names = load_tags()

        def predict(self, img):
            """Return the sigmoid of the network output for one image, on the CPU."""
            batch = preprocess(img).unsqueeze(0).to(device)
            outputs = network(batch)
            # Drop the batch dimension added above before squashing to (0, 1).
            return outputs.squeeze(0).sigmoid().cpu()

    return Model()


model = load_model()
|
|
|
|
|
|
|
|
# Page header and short usage note.
st.title("Tagger demo")
st.write("Model: [gustproof/dnbr-tagger-preview1](https://huggingface.co/gustproof/dnbr-tagger-preview1)")
st.write("Upload an image to see predicted labels.")
st.write("---")

# Sidebar: the minimum score a label needs in order to be listed.
st.sidebar.header("Settings")
confidence_threshold = st.sidebar.slider(
    label="Threshold (recommended: ~0.4-~0.6)",
    value=0.5,
    min_value=0.0,
    max_value=1.0,
    step=0.01,
)
|
|
|
|
|
|
|
|
col1, col2 = st.columns(2)

# Left column: upload widget plus a preview of the resized image.
with col1:
    st.header("Upload Image")
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png", "webp"],
    )

    if uploaded_file is not None:
        image_bytes = uploaded_file.getvalue()

        try:
            rgb_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            image = NativeResize(14, (270, 301))(rgb_image)
            w, h = image.size
            st.image(image, caption=f"Resized: {w}x{h}")
        except Exception as e:
            st.error(f"Error opening image: {e}")
            st.warning("Please upload a valid image file")
            # Clearing the upload makes the predictions column fall back to
            # its "upload an image" hint instead of using a missing image.
            uploaded_file = None
|
|
|
|
|
# Tag indices below RATING_TAG_END are shown as "Rating", indices below
# GENERAL_TAG_END as "General", and everything from there on as "Character".
RATING_TAG_END = 4
GENERAL_TAG_END = 8856
# At most this many of the highest-scoring labels are displayed.
MAX_DISPLAYED_LABELS = 200


def get_category(ti):
    """Return the display category for tag index *ti*."""
    if ti < RATING_TAG_END:
        return "Rating"
    if ti < GENERAL_TAG_END:
        return "General"
    return "Character"


# Right column: table of labels whose score reaches the sidebar threshold.
with col2:
    st.header("Predictions")
    if uploaded_file is not None:
        with st.spinner("Computing..."):
            try:
                scores = model.predict(image)

                # Apply the threshold with one tensor comparison and convert
                # to plain Python numbers once, rather than creating a 0-d
                # tensor per class by iterating over the tensor.
                passing = (scores >= confidence_threshold).nonzero().flatten().tolist()
                score_values = scores.tolist()
                filtered_results = sorted(
                    ((ti, score_values[ti]) for ti in passing),
                    key=lambda pair: pair[1],
                    reverse=True,
                )

                if filtered_results:
                    rows = [
                        (rank, tags[ti][0], f"{p:.4f}", get_category(ti), tags[ti][1])
                        for rank, (ti, p) in enumerate(
                            filtered_results[:MAX_DISPLAYED_LABELS], 1
                        )
                    ]
                    df = pd.DataFrame(
                        rows,
                        columns=[
                            "Rank",
                            "Label",
                            "Score",
                            "Category",
                            "Dataset frequency",
                        ],
                    )
                    st.dataframe(
                        df,
                        hide_index=True,
                        column_config={
                            "Dataset frequency": st.column_config.NumberColumn(
                                format="localized"
                            )
                        },
                    )
                else:
                    st.info("No labels meet the current threshold.")

            except Exception as e:
                # UI boundary: show the failure in the page instead of
                # letting the script run abort.
                st.error("An error occurred during prediction or processing:")
                st.exception(e)

    else:
        st.info("Upload an image using the panel on the left to see predictions.")
|
|
|