File size: 5,542 Bytes
c05ccb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6851e01
c05ccb6
 
 
 
 
 
 
 
1cc1d5e
 
c05ccb6
 
1cc1d5e
c05ccb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cc1d5e
c05ccb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import streamlit as st
import numpy as np
from PIL import Image
import io  # Used to handle image bytes
import torch
import timm
import json
from torchvision.transforms.v2 import (
    ToImage,
    Compose,
    ToDtype,
    Normalize,
)
import pandas as pd
import requests


# Wide layout: the page body is split into two columns further down.
st.set_page_config(layout="wide")

# Prefer the GPU when CUDA is usable, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def NativeResize(patch_size, seq_len_range):
    """Build a resizer that snaps an image to a patch-aligned size.

    Candidate output sizes are every ``(cols * patch_size, rows * patch_size)``
    with ``4 <= cols, rows < 100``, an aspect ratio ``cols / rows`` between
    0.33 and 3, and a cell count ``cols * rows`` inside ``seq_len_range``.

    Args:
        patch_size: Edge length in pixels of one grid cell; the output width
            and height are multiples of it.
        seq_len_range: ``(lo, hi)`` pair bounding ``cols * rows``; ``lo`` is
            inclusive and ``hi`` is exclusive.

    Returns:
        A callable that takes a ``PIL.Image.Image`` and returns it resized
        (bicubic resampling) to the candidate whose aspect ratio is closest
        to the input's width / height ratio.

    Raises:
        ValueError: If no candidate size satisfies ``seq_len_range``.
    """
    lo, hi = seq_len_range
    # Kept sorted so that ties in the distance metric always resolve to the
    # same candidate (min() returns the first minimal element).
    candidates = sorted(
        (cols / rows, cols * patch_size, rows * patch_size)
        for cols in range(4, 100)
        for rows in range(4, 100)
        if 0.33 <= cols / rows <= 3 and lo <= cols * rows < hi
    )
    if not candidates:
        # Fail when the resizer is built instead of on the first image, where
        # min() would only report "arg is an empty sequence".
        raise ValueError(
            f"No patch grid has a cell count in [{lo}, {hi}); "
            "widen seq_len_range."
        )

    def closest(ratio):
        # Relative distance between two ratios: larger / smaller - 1, which
        # treats "too wide" and "too tall" symmetrically.
        return min(
            candidates,
            key=lambda c: max(ratio, c[0]) / min(ratio, c[0]) - 1,
        )

    def f(im: Image.Image):
        w, h = im.size
        _, new_w, new_h = closest(w / h)
        return im.resize((new_w, new_h), resample=Image.Resampling.BICUBIC)

    return f

def load_json_from_url(url, timeout=10):
    """Fetch ``url`` with an HTTP GET and return its body parsed as JSON.

    Args:
        url: Address to request.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block forever on a stalled
            connection.

    Returns:
        The decoded JSON value, or ``None`` when the request fails (network
        error, timeout, error HTTP status) or the body is not valid JSON.
        The reason is printed in either case.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Turn error HTTP statuses into exceptions
        return json.loads(response.text)
    except requests.exceptions.RequestException as e:
        print(f"Error fetching data from URL: {e}")
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON data: {e}")
    return None


@st.cache_data
def load_tags():
    """Download the tag frequency table and append a placeholder entry.

    Returns:
        A list holding every entry of ``freqs.json`` followed by
        ``("PLACEHOLDER", 0)``.

    Raises:
        RuntimeError: If the table could not be downloaded or parsed.
    """
    freqs = load_json_from_url(
        "https://huggingface.co/gustproof/dnbr-tagger-preview1/raw/main/freqs.json"
    )
    if freqs is None:
        # The loader reports failure by returning None; unpacking that below
        # would surface as an unrelated-looking TypeError.
        raise RuntimeError(
            "Could not load tag frequencies (freqs.json) from the Hugging Face Hub"
        )
    # NOTE(review): the trailing placeholder presumably pads the table to the
    # model's output size - confirm against the checkpoint.
    return [*freqs, ("PLACEHOLDER", 0)]


# Tag table (cached by load_tags); indexed by class index when the results
# table is built: entry[0] is the label, entry[1] its dataset frequency.
tags = load_tags()


@st.cache_resource
def load_model():
    """Load the tagger from the Hugging Face Hub and wrap it for inference.

    Returns:
        An object with a ``predict(img)`` method that returns the per-class
        sigmoid probabilities as a CPU tensor, and a ``class_names``
        attribute holding the tag table from ``load_tags``.
    """
    # Grad mode is thread-local in PyTorch, so this call only covers the
    # thread that populates the cache; predict() carries its own no_grad
    # guard for calls made from any other thread.
    torch.set_grad_enabled(False)
    net = (
        timm.create_model(
            "hf_hub:gustproof/dnbr-tagger-preview1",
            pretrained=True,
            dynamic_img_size=True,
        )
        .eval()
        .to(device)
    )
    print("loaded model")
    # PIL image -> float tensor scaled to [0, 1] -> per-channel normalisation.
    preprocess = Compose(
        [
            ToImage(),
            ToDtype(torch.float, scale=True),
            Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
        ]
    )

    class Model:
        """Thin inference wrapper around the loaded network."""

        def __init__(self):
            self.class_names = load_tags()

        def predict(self, img):
            """Return the sigmoid probability of every class for ``img``."""
            # Disable autograd here rather than relying on global state, so
            # no graph is built regardless of which thread calls predict().
            with torch.no_grad():
                x = preprocess(img).unsqueeze(0).to(device)
                return net(x).squeeze(0).sigmoid().cpu()

    return Model()


# Inference wrapper (cached by load_model); its predict() is called once per
# uploaded image further down.
model = load_model()

# Class-index boundaries: indices below RATING_END are ratings, indices below
# GENERAL_END are general tags, and everything from GENERAL_END up is a
# character tag.
RATING_END = 4
GENERAL_END = 8856


def get_category(tag_index):
    """Return the category name for a class index."""
    if tag_index < RATING_END:
        return "Rating"
    if tag_index < GENERAL_END:
        return "General"
    return "Character"


# --- Streamlit App Layout ---
st.title("Tagger demo")
st.write("Model: [gustproof/dnbr-tagger-preview1](https://huggingface.co/gustproof/dnbr-tagger-preview1)")
st.write("Upload an image to see predicted labels.")
st.write("---")

# --- Sidebar for Controls ---
st.sidebar.header("Settings")
# Minimum score a label needs in order to be listed.
confidence_threshold = st.sidebar.slider(
    "Threshold (recommended: ~0.4-~0.6)",
    min_value=0.0,
    max_value=1.0,
    value=0.5,  # Default threshold
    step=0.01,
)

# --- Main Area ---
col1, col2 = st.columns(2)

# Stays None until an upload has been decoded, resized and displayed; the
# predictions column only runs the model when this is set.
image = None

with col1:
    st.header("Upload Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])

    if uploaded_file is not None:
        image_bytes = uploaded_file.getvalue()
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            image = NativeResize(14, (270, 301))(image)
            w, h = image.size
            st.image(image, caption=f"Resized: {w}x{h}")
        except Exception as e:
            image = None  # Skip prediction for an upload that failed to load
            st.error(f"Error opening image: {e}")
            st.warning("Please upload a valid image file")

with col2:
    st.header("Predictions")
    if image is None:
        st.info("Upload an image using the panel on the left to see predictions.")
    else:
        with st.spinner("Computing..."):
            try:
                # Convert once to plain floats; iterating the tensor directly
                # would allocate a 0-d tensor per class for every comparison.
                scores = model.predict(image).tolist()
                filtered_results = sorted(
                    (
                        (ti, p)
                        for ti, p in enumerate(scores)
                        if p >= confidence_threshold
                    ),
                    key=lambda result: result[1],
                    reverse=True,
                )

                if filtered_results:
                    # Show at most the 200 highest-scoring labels.
                    df = pd.DataFrame(
                        [
                            (rank, tags[ti][0], f"{p:.4f}", get_category(ti), tags[ti][1])
                            for rank, (ti, p) in enumerate(filtered_results[:200], 1)
                        ],
                        columns=[
                            "Rank",
                            "Label",
                            "Score",
                            "Category",
                            "Dataset frequency",
                        ],
                    )
                    st.dataframe(
                        df,
                        hide_index=True,
                        column_config={
                            "Dataset frequency": st.column_config.NumberColumn(
                                format="localized"
                            )
                        },
                    )
                else:
                    st.info("No labels meet the current threshold.")

            except Exception as e:
                # Top-level UI boundary: report the failure instead of crashing.
                st.error("An error occurred during prediction or processing:")
                st.exception(e)  # Shows the full traceback