import io
import base64
import json
import numpy as np
import onnxruntime as ort
from pathlib import Path
from PIL import Image, ImageFilter
from tokenizers import Tokenizer
from fastapi import FastAPI
from pydantic import BaseModel
# Directory holding the ONNX models and the tokenizer file. It is a relative
# path, so it is resolved against the process working directory.
MODELS_DIR = Path("models")
# FastAPI application object on which the HTTP routes in this file are registered.
app = FastAPI()
def make_session(path):
    """Create a CPU-only ONNX Runtime inference session for the model at *path*.

    The session uses four intra-op threads and enables every graph
    optimization level ONNX Runtime offers.
    """
    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session_options.intra_op_num_threads = 4
    return ort.InferenceSession(
        str(path),
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )
# The sessions and the tokenizer are created once at import time and reused
# by every request.
# Session that score_tiles runs on the preprocessed image batch.
vis = make_session(MODELS_DIR / "clip_visual.onnx")
# Session that encode_txt runs on the token-id matrix.
txt_sess = make_session(MODELS_DIR / "clip_text.onnx")
# Tokenizer used by encode_txt to turn prompt strings into token ids.
tok = Tokenizer.from_file(str(MODELS_DIR / "tokenizer.json"))
def preprocess(img):
    """Convert a PIL image into a normalized ``(1, 3, 224, 224)`` float32 array.

    Steps: force RGB, apply a 3x3 median filter, resize to 224x224 with
    bicubic resampling, scale pixels to [0, 1], standardize each channel with
    fixed mean/std constants, then reorder to channels-first and add a
    leading batch axis.
    """
    channel_mean = [0.48145466, 0.4578275, 0.40821073]
    channel_std = [0.26862954, 0.26130258, 0.27577711]

    denoised = img.convert("RGB").filter(ImageFilter.MedianFilter(size=3))
    resized = denoised.resize((224, 224), Image.BICUBIC)
    pixels = np.array(resized, dtype=np.float32) / 255.0
    standardized = (pixels - channel_mean) / channel_std
    channels_first = standardized.transpose(2, 0, 1)
    return channels_first[np.newaxis].astype(np.float32)
def norm(x):
    """Scale *x* to unit L2 length along its last axis.

    A small epsilon (1e-8) is added to each length so that an all-zero
    vector yields zeros instead of a division by zero.
    """
    lengths = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / (lengths + 1e-8)
def encode_txt(texts):
    """Encode prompt strings into L2-normalized text embeddings.

    Each prompt is lower-cased, tokenized, framed as ``[SOT] + tokens + [EOT]``
    and written into a zero-padded int64 row of length ``CTX``. The resulting
    matrix is fed to the text session and its first output is normalized.

    Args:
        texts: sequence of prompt strings.

    Returns:
        Array with one normalized embedding row per prompt.
    """
    SOT, EOT, CTX = 49406, 49407, 77
    ids = np.zeros((len(texts), CTX), dtype=np.int64)
    for i, t in enumerate(texts):
        # SOT/EOT are added by hand below, so stop the tokenizer from adding
        # its own special tokens (which would duplicate them if the tokenizer
        # file configures a post-processor).
        enc = tok.encode(t.lower(), add_special_tokens=False).ids
        # Truncate the token body rather than the framed row, so the EOT
        # marker is never cut off for over-long prompts.
        row = [SOT] + enc[:CTX - 2] + [EOT]
        ids[i, :len(row)] = row
    input_name = txt_sess.get_inputs()[0].name
    return norm(txt_sess.run(None, {input_name: ids})[0])
# Curated prompt sets per target label. Each value is a
# (positive_prompts, negative_prompts) tuple; get_txt_feats encodes both lists
# and score_tiles compares the best positive match with the best negative one.
# Singular and plural spellings of a label are separate keys.
PROMPTS = {
    "bicycles": (
        ["a bicycle parked on the street", "a bicycle wheel close up", "bicycle frame and handlebars",
         "people riding bicycles on road", "a mountain bike", "a road bicycle", "bicycle rack with bikes",
         "a bike leaning against wall", "bicycle tires on pavement"],
        ["grass only", "a flower garden", "a plain building wall", "empty road no vehicle",
         "sky and clouds", "a car on road", "a motorcycle", "a tree trunk"]
    ),
    "bicycle": (
        ["a bicycle", "bicycle wheel", "bicycle handlebar", "a parked bike",
         "bicycle frame", "a person riding a bike", "bicycle seat and pedals"],
        ["grass", "a flower", "a building wall", "empty ground", "a car", "a motorcycle"]
    ),
    "cars": (
        ["a car on the road", "a parked car", "car headlights at night", "car door and window",
         "a sedan car", "an SUV on the street", "car bumper and grille", "car hood and windshield",
         "a vehicle driving on highway", "cars in traffic", "car rear with taillights"],
        ["a bicycle", "grass field", "a building facade", "sky only", "a tree",
         "a bus", "a truck", "a motorcycle", "sidewalk with no cars"]
    ),
    "car": (
        ["a car", "a vehicle on road", "car headlights", "car door",
         "car windshield", "a parked automobile", "car body metal"],
        ["a bicycle", "grass", "a building", "sky", "a bus", "a truck"]
    ),
    "traffic lights": (
        ["a traffic light pole on street", "red traffic light signal", "green traffic light signal",
         "yellow traffic light", "traffic signal at intersection", "traffic light hanging above road",
         "a stoplight on pole", "pedestrian traffic signal light"],
        ["a car", "grass", "a building wall", "sky without lights", "a tree",
         "a street lamp", "a billboard", "a road sign"]
    ),
    "traffic light": (
        ["a traffic light", "traffic signal pole", "red green traffic light",
         "stoplight at intersection", "a traffic signal"],
        ["a car", "grass", "a building", "sky", "a street lamp", "a road sign"]
    ),
    "fire hydrants": (
        ["a fire hydrant on sidewalk", "a red fire hydrant", "a yellow fire hydrant",
         "fire hydrant near curb", "a standpipe hydrant on street",
         "a short red cylinder hydrant", "fire hydrant bolts on top"],
        ["a car", "grass", "a building wall", "sky", "a tree",
         "a parking meter", "a trash can", "a mailbox"]
    ),
    "fire hydrant": (
        ["a fire hydrant", "a red hydrant", "fire hydrant on sidewalk",
         "a short red yellow cylinder on street"],
        ["a car", "grass", "a building", "sky", "a parking meter"]
    ),
    "buses": (
        ["a city bus on the road", "a public transit bus", "a large passenger bus",
         "a school bus", "a double decker bus", "bus exterior side view",
         "a bus at a bus stop", "bus windows in a row", "a coach bus on highway"],
        ["a car", "a bicycle", "grass", "a building", "sky",
         "a truck", "a van", "a train"]
    ),
    "bus": (
        ["a bus", "a public bus", "large bus vehicle", "a city bus",
         "bus exterior", "a school bus"],
        ["a car", "a bicycle", "grass", "a building", "a truck"]
    ),
    "motorcycles": (
        ["a motorcycle on the road", "a person riding a motorcycle", "motorcycle wheel and engine",
         "a parked motorcycle", "motorcycle handlebars and fuel tank",
         "a motorbike on street", "a scooter motorcycle", "motorcycle exhaust pipe"],
        ["grass", "a flower", "a building wall", "sky", "a tree",
         "a bicycle", "a car", "a truck"]
    ),
    "motorcycle": (
        ["a motorcycle", "motorcycle wheel", "riding a motorcycle",
         "a motorbike", "motorcycle engine", "a scooter"],
        ["grass", "a flower", "a building", "sky", "a bicycle", "a car"]
    ),
    "crosswalks": (
        ["a crosswalk on the road", "zebra crossing white stripes", "pedestrian crossing painted lines",
         "white parallel lines on road", "a marked crosswalk at intersection",
         "crosswalk stripes on asphalt", "pedestrian walkway markings"],
        ["a car", "grass", "a building wall", "sky", "a tree",
         "a solid road surface", "a sidewalk", "a driveway"]
    ),
    "crosswalk": (
        ["a crosswalk", "zebra crossing", "pedestrian crossing",
         "white stripes on road", "crosswalk lines painted on asphalt"],
        ["a car", "grass", "a building", "sky", "plain road no markings"]
    ),
    "stairs": (
        ["stairs going up outdoors", "concrete staircase steps", "outdoor stone steps",
         "a staircase with railing", "steps leading to building entrance",
         "stair steps close up", "wooden staircase interior"],
        ["grass", "a tree", "sky", "a car", "a window", "flat ground", "a ramp"]
    ),
    "staircase": (
        ["a staircase", "stairs", "steps going up", "stair railing and steps"],
        ["grass", "a tree", "sky", "a car", "flat surface"]
    ),
    "chimneys": (
        ["a chimney on a rooftop", "brick chimney stack", "chimney on top of building",
         "a tall chimney pipe", "industrial chimney", "multiple chimneys on roof"],
        ["grass", "a car", "sky only", "a tree", "a road", "a wall", "a window"]
    ),
    "bridges": (
        ["a bridge over water", "a road bridge spanning river", "bridge structure with supports",
         "a suspension bridge", "a concrete bridge", "bridge arch over water",
         "a pedestrian bridge", "bridge girders and cables"],
        ["grass", "a car", "a building", "a tree", "a road without bridge"]
    ),
    "boats": (
        ["a boat on water", "a sailing boat", "a motorboat", "a ship at sea",
         "a rowboat on lake", "a fishing boat", "boat hull in water",
         "a yacht on ocean", "a ferry boat"],
        ["grass", "a car", "a building", "a tree", "a road", "empty water no boat"]
    ),
    "mountains": (
        ["a mountain landscape", "mountain peak with snow", "rocky mountain scenery",
         "a mountain range in background", "mountain slope with trees",
         "high altitude mountain view", "mountain ridge and valley"],
        ["a car", "a building", "a road", "a bicycle", "flat ground", "a city skyline"]
    ),
    "tractors": (
        ["a farm tractor", "a tractor in a field", "agricultural tractor working",
         "tractor large rear wheels", "a green farm tractor", "tractor on farmland"],
        ["a car", "grass without tractor", "a building", "sky", "a bicycle", "a truck"]
    ),
    "parking meters": (
        ["a parking meter on sidewalk", "coin operated parking meter",
         "a metal parking meter pole", "parking pay station on street",
         "a single post parking meter"],
        ["a car", "grass", "a building", "sky", "a tree", "a fire hydrant", "a trash can"]
    ),
    "trucks": (
        ["a large truck on the road", "a delivery truck", "a semi truck with trailer",
         "a cargo truck", "truck cab and body", "a pickup truck",
         "a freight truck on highway", "truck wheels and axle"],
        ["a car", "a bicycle", "grass", "a building", "sky", "a bus"]
    ),
    "truck": (
        ["a truck", "a delivery truck", "a pickup truck", "cargo truck body"],
        ["a car", "a bicycle", "grass", "a building", "a bus"]
    ),
    "palm trees": (
        ["a palm tree", "tropical palm tree leaves", "a tall palm trunk",
         "coconut palm tree", "palm fronds at top of tree", "a palm tree on beach"],
        ["a car", "a building", "grass", "a pine tree", "a leafy tree", "a cactus"]
    ),
    "traffic signs": (
        ["a traffic sign on pole", "a road sign", "a stop sign", "a yield sign",
         "speed limit sign on road", "a warning road sign", "directional traffic sign"],
        ["a car", "grass", "a building", "sky", "a tree", "a traffic light"]
    ),
    "vehicles": (
        ["a motor vehicle on road", "a car driving", "a bus on street",
         "a truck on highway", "a motorcycle", "a vehicle in traffic"],
        ["grass", "a building", "sky", "a tree", "a bicycle", "a person walking"]
    ),
    "airplanes": (
        ["an airplane in the sky", "a commercial aircraft", "airplane wings in flight",
         "a plane on runway", "aircraft fuselage", "a jet plane taking off"],
        ["a car", "a bird", "a building", "grass", "a boat", "clouds only"]
    ),
    "train": (
        ["a train on tracks", "a locomotive", "train cars on railway",
         "a passenger train", "train wheels on rails"],
        ["a car", "a bus", "a truck", "grass", "a building", "a road"]
    ),
    "taxicabs": (
        ["a yellow taxi cab", "a taxicab on road", "a taxi car with sign on top",
         "a cab vehicle for hire", "taxi with yellow paint"],
        ["a private car", "a bus", "a police car", "grass", "a building"]
    ),
    "store fronts": (
        ["a store front with windows", "a shop entrance facade",
         "retail store exterior", "a business storefront with sign",
         "shop window display on street"],
        ["a car", "grass", "sky", "a tree", "a house", "a warehouse"]
    ),
    "taxis": (
        ["a taxi cab", "a yellow taxi", "a cab with taxi sign",
         "a taxi vehicle on street"],
        ["a private car", "a bus", "grass", "a building"]
    ),
}
# label -> (text feature matrix, number of positive prompts)
_txt_cache = {}
# Maximum number of cache entries that non-curated labels may fill. Labels
# arrive from clients, so caching every distinct one would grow without bound.
_TXT_CACHE_MAX = 256


def get_txt_feats(label):
    """Return the encoded prompt features for *label*, caching the result.

    Curated labels use their prompt lists from ``PROMPTS``; any other label
    falls back to generic templated prompts built around the label text.

    Args:
        label: target label string used as cache key and prompt lookup key.

    Returns:
        ``(features, n_pos)``: ``features`` is the output of ``encode_txt``
        for the positive prompts followed by the negative prompts, and
        ``n_pos`` is the number of positive rows at the start.
    """
    cached = _txt_cache.get(label)
    if cached is not None:
        return cached

    if label in PROMPTS:
        pos, neg = PROMPTS[label]
    else:
        # Richer generic fallback for labels without curated prompts.
        pos = [
            f"a photo of {label}",
            f"{label} close up",
            f"an image clearly showing {label}",
            f"{label} on the street",
            f"a clear view of {label}",
        ]
        neg = [
            "grass and dirt",
            "a plain building facade",
            "sky and clouds only",
            "a tree with leaves",
            "an empty road surface",
            "blurry background texture",
        ]

    feats = (encode_txt(pos + neg), len(pos))
    # Curated labels are always cached (their count is fixed); arbitrary
    # labels are cached only while there is room, which caps the cache at
    # _TXT_CACHE_MAX + len(PROMPTS) entries.
    if label in PROMPTS or len(_txt_cache) < _TXT_CACHE_MAX:
        _txt_cache[label] = feats
    return feats
def adaptive_threshold(scores: list[float], n_tiles: int) -> float:
    """Pick a click threshold from the distribution of tile scores.

    Three regimes are distinguished:

    * Scores nearly identical (std < 0.005): return the n-th highest score,
      where n is ``n_tiles // 3`` clamped to the range 1..3 and to the number
      of scores, so only the top few tiles reach the threshold.
    * Wide spread (max - min > 0.15): return ``mean + 0.5 * std`` so only
      tiles clearly above the rest pass.
    * Otherwise: return the slightly more permissive ``mean + 0.25 * std``.

    Args:
        scores: per-tile scores; must not be empty.
        n_tiles: number of tiles in the grid.

    Returns:
        The threshold as a Python float.

    Raises:
        ValueError: if ``scores`` is empty.
    """
    arr = np.asarray(scores, dtype=float)
    if arr.size == 0:
        raise ValueError("scores must not be empty")

    mean_s = float(arr.mean())
    std_s = float(arr.std())

    if std_s < 0.005:
        # All scores look alike: take the top-N highest. N is also clamped to
        # the number of scores so a mismatched n_tiles cannot index past the
        # start of the sorted array.
        n_take = max(1, min(3, n_tiles // 3, arr.size))
        return float(np.sort(arr)[-n_take])

    spread = float(arr.max() - arr.min())
    if spread > 0.15:
        # Large gap: keep only the tiles that are clearly on top.
        return mean_s + 0.5 * std_s

    # Normal case: somewhat conservative.
    return mean_s + 0.25 * std_s
# Request body for POST /score.
class ScoreRequest(BaseModel):
    # Target label; the handler lower-cases and strips it before prompt lookup.
    label: str
    # Tile images, each a base64-encoded image file.
    tiles: list[str]
# Response body for POST /score.
class ScoreResponse(BaseModel):
    # One score per submitted tile, in request order.
    scores: list[float]
    # Cut-off that was applied to the scores.
    threshold: float
    # Indices of the tiles whose score is >= threshold.
    to_click: list[int]
# Root endpoint: returns a fixed status payload.
@app.get("/")
def root():
    return {"status": "ok"}
# Health-check endpoint: returns the same fixed status payload as "/".
@app.get("/health")
def health():
    return {"status": "ok"}
@app.post("/score", response_model=ScoreResponse)
def score_tiles(req: ScoreRequest):
    """Score each tile against the requested label and choose tiles to click.

    Every tile is base64-decoded, opened as an image and preprocessed; the
    batch is embedded with the visual session and compared with the label's
    text features. A tile's score is its best positive-prompt similarity minus
    its best negative-prompt similarity. Tiles scoring at or above an adaptive
    threshold are returned in ``to_click``.

    Raises:
        HTTPException: status 400 when ``tiles`` is empty or a tile is not
            valid base64 / not a decodable image.
    """
    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi import HTTPException

    if not req.tiles:
        raise HTTPException(status_code=400, detail="tiles must not be empty")

    imgs = []
    for idx, b64 in enumerate(req.tiles):
        try:
            raw = base64.b64decode(b64)
            with Image.open(io.BytesIO(raw)) as img:
                imgs.append(preprocess(img))
        except (ValueError, OSError) as exc:
            # binascii.Error is a ValueError; PIL reports unreadable or
            # truncated images as OSError subclasses.
            raise HTTPException(
                status_code=400,
                detail=f"tile {idx} is not a valid base64-encoded image",
            ) from exc

    label = req.label.lower().strip()
    t_feat, n_pos = get_txt_feats(label)

    batch = np.concatenate(imgs, axis=0)
    i_feat = norm(vis.run(None, {vis.get_inputs()[0].name: batch})[0])
    sims = i_feat @ t_feat.T
    # Margin between the best positive and the best negative prompt per tile.
    margins = sims[:, :n_pos].max(axis=1) - sims[:, n_pos:].max(axis=1)
    scores = [float(m) for m in margins]

    threshold = adaptive_threshold(scores, len(scores))
    to_click = [i for i, s in enumerate(scores) if s >= threshold]
    # Safety: if every tile would be clicked, the threshold is probably too
    # low, so re-derive it from the top score.
    # NOTE(review): for a negative top score, 0.95 * max lies above the max
    # and no tile is selected - confirm that this is intended.
    if len(to_click) >= len(scores):
        threshold = float(np.max(scores)) * 0.95
        to_click = [i for i, s in enumerate(scores) if s >= threshold]
    return ScoreResponse(scores=scores, threshold=threshold, to_click=to_click)