Spaces:
Runtime error
Runtime error
Deploy auto GPU fallback + FastAPI /predict
Browse files- .gitattributes +35 -35
- README.md +14 -14
- app.py +464 -423
- requirements.txt +36 -36
- scripts/tmos_classifier.py +216 -216
.gitattributes
CHANGED
|
@@ -1,35 +1,35 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: TrueFrame
|
| 3 |
-
emoji: π
|
| 4 |
-
colorFrom: yellow
|
| 5 |
-
colorTo: indigo
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 6.11.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: apache-2.0
|
| 11 |
-
short_description: A LLaVA-based multimodal classifier with LoRA fine-tuning, d
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: TrueFrame
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.11.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
short_description: A LLaVA-based multimodal classifier with LoRA fine-tuning, d
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
|
@@ -1,423 +1,464 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import time
|
| 4 |
-
import math
|
| 5 |
-
import json
|
| 6 |
-
|
| 7 |
-
from
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
import
|
| 11 |
-
|
| 12 |
-
from
|
| 13 |
-
from
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
if
|
| 64 |
-
continue
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
if
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
model
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
if
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
return
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
<
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import math
|
| 5 |
+
import json
|
| 6 |
+
import io
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import torch
|
| 12 |
+
from fastapi import FastAPI, UploadFile, File
|
| 13 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
+
from fastapi.responses import JSONResponse
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
from PIL import Image, ImageOps
|
| 17 |
+
from transformers import AutoProcessor, AutoImageProcessor, AutoModelForImageClassification
|
| 18 |
+
|
| 19 |
+
# --- Paths & import setup --------------------------------------------------
ROOT_DIR = Path(__file__).resolve().parent
SCRIPTS_DIR = ROOT_DIR / "scripts"

# Make scripts/ importable so `tmos_classifier` can be imported lazily later.
if str(SCRIPTS_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPTS_DIR))

# --- Secrets / model-source configuration ----------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token for gated/private repos

BASE_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # LLaVA backbone for the TMOS classifier
ADAPTER_PATH = ROOT_DIR / "final-production-weights" / "best_model"  # local LoRA adapter, if present
ADAPTER_REPO_ID = os.getenv("ADAPTER_REPO_ID", "Werrewulf/TMOS-DD")  # remote adapter fallback repo
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "")  # preferred subfolder inside the remote repo

# CPU-only deployments use a lightweight image classifier instead of TMOS.
CPU_FALLBACK_MODEL_ID = os.getenv("CPU_FALLBACK_MODEL_ID", "DaMsTaR/Detecto-DeepFake_Image_Detector")
# The default fallback model's output is inverted by default — presumably its
# label head is flipped relative to our Real/Fake convention (TODO confirm
# against the model card). Override via INVERT_FALLBACK_OUTPUT=true/false.
DEFAULT_INVERT_FALLBACK = CPU_FALLBACK_MODEL_ID.lower() == "damstar/detecto-deepfake_image_detector"
INVERT_FALLBACK_OUTPUT = os.getenv("INVERT_FALLBACK_OUTPUT", str(DEFAULT_INVERT_FALLBACK)).strip().lower() == "true"

# --- Inference constants ---------------------------------------------------
TMOS_PROMPT = "USER: <image>\nIs this video real or produced by AI?\nASSISTANT:"  # fixed LLaVA prompt
TARGET_IMAGE_SIZE = 336  # max side length used by preprocess_image
THRESHOLD = 0.5  # p(fake) decision boundary

# Lazily-initialised globals, populated by load_model_and_processor().
model = None
processor = None
inference_device = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def resolve_inference_device(model_obj) -> torch.device:
    """Pick the device inference should run on.

    Prefers CUDA when the runtime reports it available; otherwise inspects
    the model's accelerate ``hf_device_map`` (if any) for a CUDA placement,
    and finally falls back to CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    placement = getattr(model_obj, "hf_device_map", None)
    if isinstance(placement, dict):
        cuda_entries = [
            entry
            for entry in placement.values()
            if isinstance(entry, str) and entry.startswith("cuda")
        ]
        if cuda_entries:
            # First CUDA placement in map order, matching the original scan.
            return torch.device(cuda_entries[0])
    return torch.device("cpu")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def find_classifier_weight_tensor(model_obj):
    """Breadth-first search through common wrapper attributes for a
    ``classifier`` submodule, returning its ``weight`` tensor or ``None``.

    Walks ``model``/``base_model``/``module`` links so it works across
    HF/PEFT/accelerate wrapper layers; cycle-safe via object identity.
    """
    pending = [model_obj]
    seen_ids = set()
    cursor = 0
    while cursor < len(pending):
        node = pending[cursor]
        cursor += 1
        if node is None or id(node) in seen_ids:
            continue
        seen_ids.add(id(node))

        head = getattr(node, "classifier", None)
        if head is not None and hasattr(head, "weight"):
            return head.weight

        # Enqueue the usual wrapper attributes for the next BFS level.
        pending.extend(
            child
            for child in (getattr(node, attr, None) for attr in ("model", "base_model", "module"))
            if child is not None
        )

    return None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def count_lora_layers(model_obj) -> int:
    """Count submodules that carry both ``lora_A`` and ``lora_B`` — i.e.
    modules with a LoRA adapter actually attached."""
    return sum(
        1
        for _, submodule in model_obj.named_modules()
        if hasattr(submodule, "lora_A") and hasattr(submodule, "lora_B")
    )
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def is_tmos_adapter_config(cfg: dict) -> bool:
    """Return True when *cfg* matches the TMOS production adapter layout:
    rank-64 LoRA covering all seven projection modules, with the
    classification head saved alongside the adapter."""
    saved_modules = cfg.get("modules_to_save") or []
    if "classifier" not in saved_modules:
        return False
    if cfg.get("r") != 64:
        return False
    expected = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
    return expected <= set(cfg.get("target_modules") or [])
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def load_local_adapter_config(adapter_dir: Path) -> dict | None:
|
| 103 |
+
cfg_path = adapter_dir / "adapter_config.json"
|
| 104 |
+
if not cfg_path.exists():
|
| 105 |
+
return None
|
| 106 |
+
with cfg_path.open("r", encoding="utf-8") as fp:
|
| 107 |
+
return json.load(fp)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def load_remote_adapter_config(repo_id: str, subfolder: str) -> dict | None:
    """Probe a Hub repo for a PEFT adapter config.

    Returns the config as a dict, or None when it cannot be fetched or
    parsed — callers use this to test candidate subfolders non-fatally.
    """
    from peft import PeftConfig

    try:
        remote = PeftConfig.from_pretrained(repo_id, subfolder=subfolder, token=HF_TOKEN)
        return remote.to_dict()
    except Exception:
        # Any failure (missing repo/subfolder, auth, parse) simply means
        # "no usable adapter here".
        return None
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def select_torch_dtype() -> torch.dtype:
    """Preferred compute dtype: bf16 on bf16-capable GPUs, fp16 on other
    GPUs, fp32 on CPU."""
    if not torch.cuda.is_available():
        return torch.float32
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def load_tmos_model():
    """Load the TMOS LLaVA-based deepfake classifier into the module globals.

    Adapter resolution order:
      1. Local weights under ADAPTER_PATH (must pass the TMOS config check).
      2. The remote repo ADAPTER_REPO_ID, trying candidate subfolders until
         one loads with LoRA layers attached and a classifier head that
         measurably changed versus the base model.

    Sets the globals ``model``, ``processor`` and ``inference_device``.
    Raises RuntimeError on CPU-only hosts or when no compatible adapter
    can be loaded.
    """
    global model, processor, inference_device

    if not torch.cuda.is_available():
        raise RuntimeError(
            "TMOS mode requires GPU hardware. CPU fallback should be used on CPU-only environments."
        )

    # Heavy imports deferred so CPU-only deployments never pay for them.
    from peft import PeftModel
    from tmos_classifier import TMOSClassifier

    adapter_source = None
    # Prefer a local adapter if either serialized form is present on disk.
    local_adapter_file = next(
        (
            candidate
            for candidate in (
                ADAPTER_PATH / "adapter_model.safetensors",
                ADAPTER_PATH / "adapter_model.bin",
            )
            if candidate.exists()
        ),
        None,
    )

    selected_subfolder = ""

    if local_adapter_file is not None:
        adapter_source = str(ADAPTER_PATH)
        local_cfg = load_local_adapter_config(ADAPTER_PATH)
        if local_cfg is None or not is_tmos_adapter_config(local_cfg):
            raise RuntimeError(
                "Local adapter exists but is not TMOS-compatible. Expected modules_to_save=['classifier'], r=64, and TMOS target modules."
            )
    else:
        adapter_source = ADAPTER_REPO_ID

    dtype = select_torch_dtype()
    print(f"Loading TMOS-DD model from {adapter_source} with dtype={dtype}...")

    base_model = TMOSClassifier(
        base_model_id=BASE_MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        token=HF_TOKEN,
    )
    # Snapshot the pristine classifier head so we can later verify the
    # adapter actually overwrote it (guards against silently-ignored loads).
    base_classifier_weight = find_classifier_weight_tensor(base_model)
    base_classifier_snapshot = None
    if base_classifier_weight is not None:
        base_classifier_snapshot = base_classifier_weight.detach().float().cpu().clone()

    peft_kwargs = {"is_trainable": False, "token": HF_TOKEN}
    if adapter_source == ADAPTER_REPO_ID:
        # Known layouts of the remote repo; first valid one wins.
        candidate_subfolders = [
            s for s in [ADAPTER_SUBFOLDER, "multimodal", "multimodal/checkpoint-5", "llava"] if s is not None
        ]

        last_error = None
        for subfolder in candidate_subfolders:
            try:
                remote_cfg = load_remote_adapter_config(adapter_source, subfolder)
                if remote_cfg is None or not is_tmos_adapter_config(remote_cfg):
                    raise ValueError("Adapter config is not TMOS-compatible.")

                current_kwargs = dict(peft_kwargs)
                if subfolder:
                    current_kwargs["subfolder"] = subfolder
                loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **current_kwargs)

                # Sanity check 1: LoRA layers must actually be attached.
                lora_layer_count = count_lora_layers(loaded_model)
                if lora_layer_count == 0:
                    raise RuntimeError("Loaded adapter has zero LoRA layers attached.")

                # Sanity check 2: the classifier head must still exist...
                loaded_classifier_weight = find_classifier_weight_tensor(loaded_model)
                if loaded_classifier_weight is None:
                    raise RuntimeError("Classifier head not found after adapter load.")

                # ...and must differ from the randomly-initialised base head.
                if base_classifier_snapshot is not None:
                    classifier_delta = (
                        loaded_classifier_weight.detach().float().cpu() - base_classifier_snapshot
                    ).abs().mean().item()
                    if classifier_delta < 1e-8:
                        raise RuntimeError(
                            "Classifier weights did not change after loading adapter; adapter likely incompatible."
                        )

                # Fold LoRA weights into the base model for faster inference.
                model = loaded_model.merge_and_unload()
                selected_subfolder = subfolder
                print(
                    f"Loaded TMOS adapter from repo subfolder: '{subfolder or '.'}' "
                    f"(lora_layers={lora_layer_count})"
                )
                break
            except Exception as exc:
                # Remember the failure and try the next candidate subfolder.
                last_error = exc
                continue
        else:
            # for/else: reached only when no candidate subfolder succeeded.
            raise RuntimeError(
                "No TMOS-compatible adapter found in remote repo. Upload TMOS production weights with classifier head "
                "(modules_to_save=['classifier'], r=64, 7-target-module LoRA)."
            ) from last_error
    else:
        # Local adapter path (config already validated above).
        loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **peft_kwargs)
        lora_layer_count = count_lora_layers(loaded_model)
        if lora_layer_count == 0:
            raise RuntimeError("Local adapter load produced zero LoRA layers attached.")
        model = loaded_model.merge_and_unload()
        print(f"Loaded TMOS local adapter (lora_layers={lora_layer_count})")

    model.eval()
    processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
    # NOTE(review): processor settings hard-coded here — presumably matching
    # the LLaVA-1.5 vision tower; confirm against the training pipeline.
    processor.patch_size = 14
    processor.vision_feature_select_strategy = "default"
    inference_device = resolve_inference_device(model)
    if adapter_source == ADAPTER_REPO_ID:
        print(f"TMOS-DD ready on {inference_device} using remote subfolder '{selected_subfolder or '.'}'.")
    else:
        print(f"TMOS-DD ready on {inference_device} using local production adapter.")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def load_cpu_fallback_model():
    """Load the lightweight CPU image classifier into the module globals.

    Used when no GPU is available or when TMOS loading fails; sets
    ``model``, ``processor`` and ``inference_device`` (always CPU here).
    """
    global model, processor, inference_device
    print(f"Loading CPU fallback model from {CPU_FALLBACK_MODEL_ID}...")

    processor = AutoImageProcessor.from_pretrained(CPU_FALLBACK_MODEL_ID, token=HF_TOKEN)
    model = AutoModelForImageClassification.from_pretrained(
        CPU_FALLBACK_MODEL_ID,
        torch_dtype=torch.float32,  # full precision: CPU path, accuracy over speed
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
    )
    model.to("cpu").eval()
    inference_device = torch.device("cpu")
    print("CPU fallback classifier ready.")
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def load_model_and_processor():
    """Lazily initialise and return the global (model, processor, device) triple.

    GPU hosts try the TMOS LLaVA model first and fall back to the CPU
    classifier on any failure; CPU-only hosts go straight to the fallback.
    Subsequent calls return the cached globals unchanged.
    """
    global model, processor, inference_device

    already_loaded = model is not None and processor is not None and inference_device is not None
    if not already_loaded:
        if not torch.cuda.is_available():
            print("No GPU detected -> using CPU fallback")
            load_cpu_fallback_model()
        else:
            print("GPU detected -> loading TMOS")
            try:
                load_tmos_model()
            except Exception as exc:
                print(f"TMOS failed: {exc}")
                print("Falling back to CPU model...")
                load_cpu_fallback_model()

    return model, processor, inference_device
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def preprocess_image(image: Image.Image) -> Image.Image:
    """Normalise an input image for the detector: force RGB and shrink it to
    fit within TARGET_IMAGE_SIZE x TARGET_IMAGE_SIZE, preserving aspect ratio."""
    rgb = image.convert("RGB")
    bounds = (TARGET_IMAGE_SIZE, TARGET_IMAGE_SIZE)
    return ImageOps.contain(rgb, bounds, method=Image.Resampling.BICUBIC)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def confidence_card(prob_fake: float, label: str) -> str:
    """Render an HTML card summarising the verdict.

    Shows the confidence for the predicted *label* (p_fake when "Fake",
    1 - p_fake otherwise), a fake-probability bar, and both percentages.
    """
    confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
    confidence_pct = confidence * 100.0
    fake_pct = prob_fake * 100.0
    real_pct = (1.0 - prob_fake) * 100.0
    # Red accent for "Fake", green for "Real".
    accent = "#ef4444" if label == "Fake" else "#10b981"

    return f"""
    <div style="border:1px solid rgba(255,255,255,0.12); border-radius:16px; padding:16px; background:linear-gradient(135deg, rgba(17,24,39,0.92), rgba(15,23,42,0.96)); color:white;">
    <div style="font-size:0.85rem; opacity:0.8; letter-spacing:0.04em; text-transform:uppercase; margin-bottom:8px;">Confidence</div>
    <div style="display:flex; align-items:baseline; gap:10px; margin-bottom:12px;">
    <div style="font-size:2rem; font-weight:700; color:{accent};">{confidence_pct:.2f}%</div>
    <div style="font-size:1rem; opacity:0.9;">for <strong>{label}</strong></div>
    </div>
    <div style="height:12px; width:100%; background:rgba(255,255,255,0.08); border-radius:999px; overflow:hidden; margin-bottom:10px;">
    <div style="height:100%; width:{fake_pct:.2f}%; background:linear-gradient(90deg, #f87171, #ef4444);"></div>
    </div>
    <div style="display:flex; justify-content:space-between; font-size:0.9rem; opacity:0.95;">
    <span>Real: {real_pct:.2f}%</span>
    <span>Fake: {fake_pct:.2f}%</span>
    </div>
    </div>
    """
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def score_fallback_logits(logits: torch.Tensor, id2label: dict) -> tuple[float, str]:
    """Map a fallback classifier's logits to ``(p_fake, "Fake"/"Real")``.

    Indices are bucketed as fake/real by keyword-matching their label names;
    an unlabeled two-class head defaults to index 1 = fake. Honors the
    INVERT_FALLBACK_OUTPUT switch for heads known to be flipped.
    """
    probs = torch.softmax(logits.float(), dim=0)

    fake_idx: list[int] = []
    real_idx: list[int] = []
    for position in range(len(probs)):
        name = str(id2label.get(position, "")).lower()
        # NOTE(review): substring matching is loose ("ai" also matches e.g.
        # "chair") — acceptable for the binary real/fake heads targeted here.
        if any(keyword in name for keyword in ["fake", "deepfake", "ai", "synthetic"]):
            fake_idx.append(position)
        if any(keyword in name for keyword in ["real", "authentic", "genuine"]):
            real_idx.append(position)

    # Unlabeled two-class head: assume conventional [real, fake] ordering.
    if len(probs) == 2 and not fake_idx and not real_idx:
        fake_idx, real_idx = [1], [0]

    mass_fake = float(probs[fake_idx].sum().item()) if fake_idx else 0.0
    mass_real = float(probs[real_idx].sum().item()) if real_idx else 0.0

    combined = mass_fake + mass_real
    if combined > 0:
        prob_fake = mass_fake / combined
    elif len(probs) == 1:
        prob_fake = float(probs.max().item())
    elif len(probs) > 1:
        prob_fake = float(probs[1].item())
    else:
        prob_fake = 0.5

    if INVERT_FALLBACK_OUTPUT:
        prob_fake = 1.0 - prob_fake

    verdict = "Fake" if prob_fake >= THRESHOLD else "Real"
    return prob_fake, verdict
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def infer_image(image: Image.Image):
    """Run deepfake detection on *image* and return the 6-tuple of Gradio
    outputs: (preview image, label, p_fake, confidence %, latency ms, HTML card).

    On any failure this returns (None, error message, None, None, None,
    error HTML) instead of raising, so the UI always gets renderable values.
    """
    try:
        if image is None:
            return None, "Error: please upload an image.", None, None, None, "<div style='color:#f87171;'>Please upload an image before running detection.</div>"

        model_obj, processor_obj, device = load_model_and_processor()
        prepared_image = preprocess_image(image)

        # Mixed-precision autocast only on CUDA; CPU runs in fp32.
        autocast_context = (
            torch.autocast(device_type="cuda", dtype=select_torch_dtype())
            if device.type == "cuda"
            else nullcontext()
        )

        start_time = time.perf_counter()
        with torch.inference_mode(), autocast_context:
            if inference_device.type == "cuda":
                # TMOS path: multimodal prompt + image through the LLaVA classifier.
                inputs = processor_obj(text=TMOS_PROMPT, images=prepared_image, return_tensors="pt", padding=True)
                inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
                outputs = model_obj(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    attention_mask=inputs["attention_mask"],
                )
                # Single binary logit -> sigmoid -> p(fake).
                logit = float(outputs["logit"].squeeze().detach().float().cpu().item())
                if not math.isfinite(logit):
                    raise gr.Error("Model produced a non-finite logit (NaN/Inf). Please retry.")
                prob_fake = float(torch.sigmoid(torch.tensor(logit)).item())
                label = "Fake" if prob_fake >= THRESHOLD else "Real"
            else:
                # CPU fallback path: plain image classifier with labeled logits.
                inputs = processor_obj(images=prepared_image, return_tensors="pt")
                inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
                outputs = model_obj(**inputs)
                logits = outputs.logits.squeeze(0).detach().float().cpu()
                id2label = getattr(model_obj.config, "id2label", {}) or {}
                prob_fake, label = score_fallback_logits(logits, id2label)

        if device.type == "cuda":
            # Ensure queued GPU work is included in the latency measurement.
            torch.cuda.synchronize()

        elapsed_ms = (time.perf_counter() - start_time) * 1000.0
        if not math.isfinite(prob_fake):
            raise gr.Error("Model produced a non-finite probability (NaN/Inf). Please retry.")

        confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
        return prepared_image, label, round(prob_fake, 6), round(confidence * 100.0, 2), round(elapsed_ms, 2), confidence_card(prob_fake, label)
    except Exception as exc:
        # NOTE(review): this also catches the gr.Error raised above, turning it
        # into an inline error card rather than a Gradio toast — confirm intended.
        err = f"Inference failed: {type(exc).__name__}: {exc}"
        err_html = f"<div style='color:#fca5a5; border:1px solid rgba(252,165,165,0.35); padding:10px; border-radius:10px;'>\n<b>Inference error</b><br>{err}</div>"
        return None, err, None, None, None, err_html
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# FastAPI application hosting both the JSON /predict endpoint and the
# mounted Gradio UI.
api = FastAPI()

# NOTE(review): fully-open CORS (any origin, with credentials) — acceptable
# for a public demo Space, but tighten allow_origins for production use.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
@api.post("/predict")
async def predict(file: UploadFile = File(...)):
    """JSON inference endpoint.

    Accepts an uploaded image file and returns the verdict, confidence,
    fake probability, and latency. Failures — including in-band failures
    from ``infer_image`` — are reported as HTTP 500 with an ``error`` key,
    never as a 200 with a bogus verdict.
    """
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")

        _, label, prob_fake, confidence, latency, _ = infer_image(image)

        # infer_image signals failure in-band: the numeric fields come back
        # None and `label` carries the error message. Surface that as an
        # error response instead of a success payload.
        if prob_fake is None:
            return JSONResponse({"error": label}, status_code=500)

        return JSONResponse(
            {
                "verdict": label,
                "confidence_percent": confidence,
                "p_fake": prob_fake,
                "latency_ms": latency,
            }
        )
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# Warm the model at import time so the first request doesn't pay the load cost.
load_model_and_processor()

with gr.Blocks(title="TMOS Deepfake Detector", theme=gr.themes.Soft()) as demo:
    device_label = "GPU (TMOS Model)" if torch.cuda.is_available() else "CPU Fallback Model"
    gr.Markdown(
        f"# TMOS Deepfake Detector\n"
        f"**Running on:** {device_label}\n\n"
        f"> Warning: runs on free infrastructure, so startup and inference may take time."
    )

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload image")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction", interactive=False)
            probability_output = gr.Number(label="P(fake)", interactive=False, precision=6)
            confidence_output = gr.Number(label="Confidence (%)", interactive=False, precision=2)
            latency_output = gr.Number(label="Latency (ms)", interactive=False, precision=2)

    preview_output = gr.Image(label="Processed image passed to the model", interactive=False)
    confidence_html = gr.HTML()

    detect_button = gr.Button("Run detection", variant="primary")

    # Output order must match infer_image's 6-tuple return.
    detect_button.click(
        fn=infer_image,
        inputs=image_input,
        outputs=[preview_output, prediction_output, probability_output, confidence_output, latency_output, confidence_html],
    )

# Serialise inference (one request at a time) and cap the waiting queue.
demo.queue(default_concurrency_limit=1, max_size=8)

# Combined ASGI app: Gradio UI at "/" plus the FastAPI /predict route.
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    # PORT defaults to 7860 (the standard Hugging Face Spaces port).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
|
requirements.txt
CHANGED
|
@@ -1,36 +1,36 @@
|
|
| 1 |
-
# Core ML
|
| 2 |
-
numpy>=1.24.0
|
| 3 |
-
python-dotenv
|
| 4 |
-
gradio
|
| 5 |
-
torch>=2.0.0
|
| 6 |
-
torchvision>=0.15.0
|
| 7 |
-
torchaudio>=2.0.0
|
| 8 |
-
torchcodec
|
| 9 |
-
Pillow
|
| 10 |
-
|
| 11 |
-
# Dependencies
|
| 12 |
-
albumentations>=0.5.2
|
| 13 |
-
datasets
|
| 14 |
-
huggingface_hub
|
| 15 |
-
scikit-learn>=1.3.0
|
| 16 |
-
scikit-image>=0.21.0
|
| 17 |
-
pandas>=2.0.0
|
| 18 |
-
matplotlib>=3.7.0
|
| 19 |
-
seaborn
|
| 20 |
-
transformers==4.36.2
|
| 21 |
-
peft
|
| 22 |
-
accelerate
|
| 23 |
-
diffusers
|
| 24 |
-
opencv-python
|
| 25 |
-
|
| 26 |
-
# Optional fallback for lower-memory GPU execution
|
| 27 |
-
bitsandbytes
|
| 28 |
-
|
| 29 |
-
# M2TR specific
|
| 30 |
-
yacs==0.1.8
|
| 31 |
-
nbconvert
|
| 32 |
-
tensorboard==2.20.0
|
| 33 |
-
tqdm==4.67.1
|
| 34 |
-
PyYAML==6.0.3
|
| 35 |
-
simplejson==3.20.2
|
| 36 |
-
fvcore
|
|
|
|
| 1 |
+
# Core ML
|
| 2 |
+
numpy>=1.24.0
|
| 3 |
+
python-dotenv
|
| 4 |
+
gradio
|
| 5 |
+
torch>=2.0.0
|
| 6 |
+
torchvision>=0.15.0
|
| 7 |
+
torchaudio>=2.0.0
|
| 8 |
+
torchcodec
|
| 9 |
+
Pillow
|
| 10 |
+
|
| 11 |
+
# Dependencies
|
| 12 |
+
albumentations>=0.5.2
|
| 13 |
+
datasets
|
| 14 |
+
huggingface_hub
|
| 15 |
+
scikit-learn>=1.3.0
|
| 16 |
+
scikit-image>=0.21.0
|
| 17 |
+
pandas>=2.0.0
|
| 18 |
+
matplotlib>=3.7.0
|
| 19 |
+
seaborn
|
| 20 |
+
transformers==4.36.2
|
| 21 |
+
peft
|
| 22 |
+
accelerate
|
| 23 |
+
diffusers
|
| 24 |
+
opencv-python
|
| 25 |
+
|
| 26 |
+
# Optional fallback for lower-memory GPU execution
|
| 27 |
+
bitsandbytes
|
| 28 |
+
|
| 29 |
+
# M2TR specific
|
| 30 |
+
yacs==0.1.8
|
| 31 |
+
nbconvert
|
| 32 |
+
tensorboard==2.20.0
|
| 33 |
+
tqdm==4.67.1
|
| 34 |
+
PyYAML==6.0.3
|
| 35 |
+
simplejson==3.20.2
|
| 36 |
+
fvcore
|
scripts/tmos_classifier.py
CHANGED
|
@@ -1,216 +1,216 @@
|
|
| 1 |
-
"""
|
| 2 |
-
TMOS_Classifier: Binary classification head on top of LLaVA's transformer backbone.
|
| 3 |
-
|
| 4 |
-
Strips the autoregressive lm_head and replaces it with a single nn.Linear(hidden_size, 1)
|
| 5 |
-
for binary deepfake detection (0 = Real, 1 = Fake).
|
| 6 |
-
|
| 7 |
-
Usage:
|
| 8 |
-
from tmos_classifier import TMOSClassifier, TMOS_LORA_CONFIG
|
| 9 |
-
|
| 10 |
-
classifier = TMOSClassifier(base_model_id="llava-hf/llava-1.5-7b-hf")
|
| 11 |
-
classifier = get_peft_model(classifier, TMOS_LORA_CONFIG)
|
| 12 |
-
|
| 13 |
-
logit = classifier(input_ids=..., pixel_values=..., attention_mask=...)
|
| 14 |
-
loss = nn.BCEWithLogitsLoss()(logit, label)
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
import torch
|
| 18 |
-
import torch.nn as nn
|
| 19 |
-
from transformers import LlavaForConditionalGeneration
|
| 20 |
-
from peft import LoraConfig
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
# βββ LoRA Configuration ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
-
# Massive expansion: r=64 across ALL linear layers in the LLM backbone.
|
| 25 |
-
# We exclude lm_head (we discard it), fc1/fc2/out_proj (CLIP vision),
|
| 26 |
-
# and linear_1/linear_2 (multi-modal projector) from LoRA to keep
|
| 27 |
-
# the vision encoder frozen and only adapt the language transformer.
|
| 28 |
-
|
| 29 |
-
TMOS_LORA_CONFIG = LoraConfig(
|
| 30 |
-
r=64,
|
| 31 |
-
lora_alpha=128, # 2x rank as a common heuristic
|
| 32 |
-
target_modules=[
|
| 33 |
-
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 34 |
-
"gate_proj", "up_proj", "down_proj",
|
| 35 |
-
],
|
| 36 |
-
lora_dropout=0.1,
|
| 37 |
-
bias="none",
|
| 38 |
-
task_type=None, # Custom classifier β not a causal LM
|
| 39 |
-
modules_to_save=["classifier"], # Always train the classification head
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class TMOSClassifier(nn.Module):
|
| 44 |
-
"""
|
| 45 |
-
Binary classifier built on the LLaVA transformer backbone.
|
| 46 |
-
|
| 47 |
-
Architecture:
|
| 48 |
-
pixel_values βββΊ CLIP Vision Tower βββΊ Multi-Modal Projector βββ
|
| 49 |
-
ββββΊ LLaMA Transformer βββΊ last_hidden_state[:, -1, :] βββΊ classifier βββΊ logit
|
| 50 |
-
input_ids βββΊ Token Embedding ββββββββββββββββββββββββββββββββββ
|
| 51 |
-
|
| 52 |
-
The lm_head is never used. We extract the final token's hidden state
|
| 53 |
-
and pass it through a learned nn.Linear(hidden_size, 1) head.
|
| 54 |
-
"""
|
| 55 |
-
|
| 56 |
-
def __init__(self, base_model_id, torch_dtype=torch.float16, device_map="auto", token=None):
|
| 57 |
-
super().__init__()
|
| 58 |
-
|
| 59 |
-
# Load the full LLaVA model (we need vision tower + projector + LLM)
|
| 60 |
-
self.base = LlavaForConditionalGeneration.from_pretrained(
|
| 61 |
-
base_model_id,
|
| 62 |
-
torch_dtype=torch_dtype,
|
| 63 |
-
low_cpu_mem_usage=True,
|
| 64 |
-
device_map=device_map,
|
| 65 |
-
token=token,
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
hidden_size = self.base.config.text_config.hidden_size # 4096 for 7B
|
| 69 |
-
|
| 70 |
-
# Freeze the lm_head β we won't use it, but freezing prevents
|
| 71 |
-
# wasted gradient computation if PEFT accidentally wraps it.
|
| 72 |
-
for param in self.base.lm_head.parameters():
|
| 73 |
-
param.requires_grad = False
|
| 74 |
-
|
| 75 |
-
# Keep the classifier head in fp32 for numerical stability.
|
| 76 |
-
self.classifier = nn.Linear(hidden_size, 1, dtype=torch.float32)
|
| 77 |
-
nn.init.xavier_uniform_(self.classifier.weight)
|
| 78 |
-
nn.init.zeros_(self.classifier.bias)
|
| 79 |
-
|
| 80 |
-
def forward(
|
| 81 |
-
self,
|
| 82 |
-
input_ids=None,
|
| 83 |
-
pixel_values=None,
|
| 84 |
-
attention_mask=None,
|
| 85 |
-
labels=None, # float tensor of shape (B,) β 0.0=real, 1.0=fake
|
| 86 |
-
**kwargs, # absorb extra keys from data collator
|
| 87 |
-
):
|
| 88 |
-
"""
|
| 89 |
-
Single deterministic forward pass β logit + optional BCE loss.
|
| 90 |
-
|
| 91 |
-
Returns:
|
| 92 |
-
dict with keys:
|
| 93 |
-
"logit": (B, 1) raw logit
|
| 94 |
-
"loss": scalar BCE loss (only if labels provided)
|
| 95 |
-
"""
|
| 96 |
-
# ββ 1. Forward through the LLaVA backbone ββ
|
| 97 |
-
# We call the internal model (vision + projector + LLM) directly,
|
| 98 |
-
# asking for hidden states, NOT for language-model logits.
|
| 99 |
-
outputs = self.base.model(
|
| 100 |
-
input_ids=input_ids,
|
| 101 |
-
pixel_values=pixel_values,
|
| 102 |
-
attention_mask=attention_mask,
|
| 103 |
-
return_dict=True,
|
| 104 |
-
)
|
| 105 |
-
|
| 106 |
-
# last_hidden_state: (B, seq_len, hidden_size)
|
| 107 |
-
last_hidden_state = outputs.last_hidden_state
|
| 108 |
-
|
| 109 |
-
# ββ 2. Pool: extract the final non-padded token per sequence ββ
|
| 110 |
-
if attention_mask is not None:
|
| 111 |
-
# Sum of mask gives the sequence length (excluding padding)
|
| 112 |
-
# Index of the last real token = seq_lengths - 1
|
| 113 |
-
seq_lengths = attention_mask.sum(dim=1).long() - 1
|
| 114 |
-
# Clamp to valid range
|
| 115 |
-
seq_lengths = seq_lengths.clamp(min=0, max=last_hidden_state.size(1) - 1)
|
| 116 |
-
# Gather the hidden state at each sequence's last real token
|
| 117 |
-
pooled = last_hidden_state[
|
| 118 |
-
torch.arange(last_hidden_state.size(0), device=last_hidden_state.device),
|
| 119 |
-
seq_lengths,
|
| 120 |
-
]
|
| 121 |
-
else:
|
| 122 |
-
# No mask β just take the last position
|
| 123 |
-
pooled = last_hidden_state[:, -1, :]
|
| 124 |
-
|
| 125 |
-
# Replace non-finite activations defensively before the classifier.
|
| 126 |
-
pooled = torch.nan_to_num(pooled, nan=0.0, posinf=1e4, neginf=-1e4)
|
| 127 |
-
|
| 128 |
-
# Match classifier device to pooled features when model is sharded/offloaded.
|
| 129 |
-
if self.classifier.weight.device != pooled.device:
|
| 130 |
-
self.classifier = self.classifier.to(pooled.device)
|
| 131 |
-
|
| 132 |
-
# ββ 3. Classify ββ
|
| 133 |
-
logit = self.classifier(pooled.float()) # (B, 1)
|
| 134 |
-
logit = torch.nan_to_num(logit, nan=0.0, posinf=20.0, neginf=-20.0)
|
| 135 |
-
|
| 136 |
-
result = {"logit": logit}
|
| 137 |
-
|
| 138 |
-
# ββ 4. Loss ββ
|
| 139 |
-
if labels is not None:
|
| 140 |
-
labels = labels.to(logit.dtype).to(logit.device)
|
| 141 |
-
if labels.dim() == 1:
|
| 142 |
-
labels = labels.unsqueeze(1) # (B,) β (B, 1)
|
| 143 |
-
loss_fn = nn.BCEWithLogitsLoss()
|
| 144 |
-
result["loss"] = loss_fn(logit, labels)
|
| 145 |
-
|
| 146 |
-
return result
|
| 147 |
-
|
| 148 |
-
def prepare_inputs_for_generation(self, *args, **kwargs):
|
| 149 |
-
"""Stub required by PEFT β we never generate text."""
|
| 150 |
-
raise NotImplementedError("TMOSClassifier does not support generation.")
|
| 151 |
-
|
| 152 |
-
def gradient_checkpointing_enable(self, **kwargs):
|
| 153 |
-
"""Delegate to the base model for HF Trainer compatibility."""
|
| 154 |
-
self.base.model.gradient_checkpointing_enable(**kwargs)
|
| 155 |
-
|
| 156 |
-
@property
|
| 157 |
-
def config(self):
|
| 158 |
-
"""Expose the base model config for PEFT."""
|
| 159 |
-
return self.base.config
|
| 160 |
-
|
| 161 |
-
@property
|
| 162 |
-
def device(self):
|
| 163 |
-
return next(self.parameters()).device
|
| 164 |
-
|
| 165 |
-
@property
|
| 166 |
-
def dtype(self):
|
| 167 |
-
return next(self.parameters()).dtype
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
# βββ Standalone Test ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 171 |
-
if __name__ == "__main__":
|
| 172 |
-
import os
|
| 173 |
-
from dotenv import load_dotenv
|
| 174 |
-
load_dotenv()
|
| 175 |
-
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 176 |
-
|
| 177 |
-
print("Testing TMOSClassifier...")
|
| 178 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 179 |
-
|
| 180 |
-
clf = TMOSClassifier(
|
| 181 |
-
base_model_id="llava-hf/llava-1.5-7b-hf",
|
| 182 |
-
torch_dtype=torch.float16,
|
| 183 |
-
token=HF_TOKEN,
|
| 184 |
-
)
|
| 185 |
-
clf.to(device)
|
| 186 |
-
|
| 187 |
-
# Print parameter counts
|
| 188 |
-
total = sum(p.numel() for p in clf.parameters())
|
| 189 |
-
trainable = sum(p.numel() for p in clf.parameters() if p.requires_grad)
|
| 190 |
-
print(f"Total params: {total:>12,}")
|
| 191 |
-
print(f"Trainable params: {trainable:>12,}")
|
| 192 |
-
print(f"Classifier head: {sum(p.numel() for p in clf.classifier.parameters()):,}")
|
| 193 |
-
|
| 194 |
-
# Smoke test with dummy input
|
| 195 |
-
from transformers import AutoProcessor
|
| 196 |
-
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", token=HF_TOKEN)
|
| 197 |
-
processor.patch_size = 14
|
| 198 |
-
processor.vision_feature_select_strategy = "default"
|
| 199 |
-
|
| 200 |
-
from PIL import Image
|
| 201 |
-
dummy_img = Image.new("RGB", (336, 336), color=(128, 128, 128))
|
| 202 |
-
inputs = processor(
|
| 203 |
-
text="USER: <image>\nIs this real?\nASSISTANT:",
|
| 204 |
-
images=dummy_img,
|
| 205 |
-
return_tensors="pt",
|
| 206 |
-
).to(device)
|
| 207 |
-
|
| 208 |
-
labels = torch.tensor([1.0], device=device) # fake
|
| 209 |
-
|
| 210 |
-
with torch.no_grad():
|
| 211 |
-
out = clf(**inputs, labels=labels)
|
| 212 |
-
|
| 213 |
-
print(f"Logit: {out['logit'].item():.4f}")
|
| 214 |
-
print(f"Loss: {out['loss'].item():.4f}")
|
| 215 |
-
print(f"Prob: {torch.sigmoid(out['logit']).item():.4f}")
|
| 216 |
-
print("Test passed.")
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TMOS_Classifier: Binary classification head on top of LLaVA's transformer backbone.
|
| 3 |
+
|
| 4 |
+
Strips the autoregressive lm_head and replaces it with a single nn.Linear(hidden_size, 1)
|
| 5 |
+
for binary deepfake detection (0 = Real, 1 = Fake).
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
from tmos_classifier import TMOSClassifier, TMOS_LORA_CONFIG
|
| 9 |
+
|
| 10 |
+
classifier = TMOSClassifier(base_model_id="llava-hf/llava-1.5-7b-hf")
|
| 11 |
+
classifier = get_peft_model(classifier, TMOS_LORA_CONFIG)
|
| 12 |
+
|
| 13 |
+
logit = classifier(input_ids=..., pixel_values=..., attention_mask=...)
|
| 14 |
+
loss = nn.BCEWithLogitsLoss()(logit, label)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from transformers import LlavaForConditionalGeneration
|
| 20 |
+
from peft import LoraConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# βββ LoRA Configuration ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
# Massive expansion: r=64 across ALL linear layers in the LLM backbone.
|
| 25 |
+
# We exclude lm_head (we discard it), fc1/fc2/out_proj (CLIP vision),
|
| 26 |
+
# and linear_1/linear_2 (multi-modal projector) from LoRA to keep
|
| 27 |
+
# the vision encoder frozen and only adapt the language transformer.
|
| 28 |
+
|
| 29 |
+
TMOS_LORA_CONFIG = LoraConfig(
|
| 30 |
+
r=64,
|
| 31 |
+
lora_alpha=128, # 2x rank as a common heuristic
|
| 32 |
+
target_modules=[
|
| 33 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 34 |
+
"gate_proj", "up_proj", "down_proj",
|
| 35 |
+
],
|
| 36 |
+
lora_dropout=0.1,
|
| 37 |
+
bias="none",
|
| 38 |
+
task_type=None, # Custom classifier β not a causal LM
|
| 39 |
+
modules_to_save=["classifier"], # Always train the classification head
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TMOSClassifier(nn.Module):
|
| 44 |
+
"""
|
| 45 |
+
Binary classifier built on the LLaVA transformer backbone.
|
| 46 |
+
|
| 47 |
+
Architecture:
|
| 48 |
+
pixel_values βββΊ CLIP Vision Tower βββΊ Multi-Modal Projector βββ
|
| 49 |
+
ββββΊ LLaMA Transformer βββΊ last_hidden_state[:, -1, :] βββΊ classifier βββΊ logit
|
| 50 |
+
input_ids βββΊ Token Embedding ββββββββββββββββββββββββββββββββββ
|
| 51 |
+
|
| 52 |
+
The lm_head is never used. We extract the final token's hidden state
|
| 53 |
+
and pass it through a learned nn.Linear(hidden_size, 1) head.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, base_model_id, torch_dtype=torch.float16, device_map="auto", token=None):
|
| 57 |
+
super().__init__()
|
| 58 |
+
|
| 59 |
+
# Load the full LLaVA model (we need vision tower + projector + LLM)
|
| 60 |
+
self.base = LlavaForConditionalGeneration.from_pretrained(
|
| 61 |
+
base_model_id,
|
| 62 |
+
torch_dtype=torch_dtype,
|
| 63 |
+
low_cpu_mem_usage=True,
|
| 64 |
+
device_map=device_map,
|
| 65 |
+
token=token,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
hidden_size = self.base.config.text_config.hidden_size # 4096 for 7B
|
| 69 |
+
|
| 70 |
+
# Freeze the lm_head β we won't use it, but freezing prevents
|
| 71 |
+
# wasted gradient computation if PEFT accidentally wraps it.
|
| 72 |
+
for param in self.base.lm_head.parameters():
|
| 73 |
+
param.requires_grad = False
|
| 74 |
+
|
| 75 |
+
# Keep the classifier head in fp32 for numerical stability.
|
| 76 |
+
self.classifier = nn.Linear(hidden_size, 1, dtype=torch.float32)
|
| 77 |
+
nn.init.xavier_uniform_(self.classifier.weight)
|
| 78 |
+
nn.init.zeros_(self.classifier.bias)
|
| 79 |
+
|
| 80 |
+
def forward(
|
| 81 |
+
self,
|
| 82 |
+
input_ids=None,
|
| 83 |
+
pixel_values=None,
|
| 84 |
+
attention_mask=None,
|
| 85 |
+
labels=None, # float tensor of shape (B,) β 0.0=real, 1.0=fake
|
| 86 |
+
**kwargs, # absorb extra keys from data collator
|
| 87 |
+
):
|
| 88 |
+
"""
|
| 89 |
+
Single deterministic forward pass β logit + optional BCE loss.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
dict with keys:
|
| 93 |
+
"logit": (B, 1) raw logit
|
| 94 |
+
"loss": scalar BCE loss (only if labels provided)
|
| 95 |
+
"""
|
| 96 |
+
# ββ 1. Forward through the LLaVA backbone ββ
|
| 97 |
+
# We call the internal model (vision + projector + LLM) directly,
|
| 98 |
+
# asking for hidden states, NOT for language-model logits.
|
| 99 |
+
outputs = self.base.model(
|
| 100 |
+
input_ids=input_ids,
|
| 101 |
+
pixel_values=pixel_values,
|
| 102 |
+
attention_mask=attention_mask,
|
| 103 |
+
return_dict=True,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# last_hidden_state: (B, seq_len, hidden_size)
|
| 107 |
+
last_hidden_state = outputs.last_hidden_state
|
| 108 |
+
|
| 109 |
+
# ββ 2. Pool: extract the final non-padded token per sequence ββ
|
| 110 |
+
if attention_mask is not None:
|
| 111 |
+
# Sum of mask gives the sequence length (excluding padding)
|
| 112 |
+
# Index of the last real token = seq_lengths - 1
|
| 113 |
+
seq_lengths = attention_mask.sum(dim=1).long() - 1
|
| 114 |
+
# Clamp to valid range
|
| 115 |
+
seq_lengths = seq_lengths.clamp(min=0, max=last_hidden_state.size(1) - 1)
|
| 116 |
+
# Gather the hidden state at each sequence's last real token
|
| 117 |
+
pooled = last_hidden_state[
|
| 118 |
+
torch.arange(last_hidden_state.size(0), device=last_hidden_state.device),
|
| 119 |
+
seq_lengths,
|
| 120 |
+
]
|
| 121 |
+
else:
|
| 122 |
+
# No mask β just take the last position
|
| 123 |
+
pooled = last_hidden_state[:, -1, :]
|
| 124 |
+
|
| 125 |
+
# Replace non-finite activations defensively before the classifier.
|
| 126 |
+
pooled = torch.nan_to_num(pooled, nan=0.0, posinf=1e4, neginf=-1e4)
|
| 127 |
+
|
| 128 |
+
# Match classifier device to pooled features when model is sharded/offloaded.
|
| 129 |
+
if self.classifier.weight.device != pooled.device:
|
| 130 |
+
self.classifier = self.classifier.to(pooled.device)
|
| 131 |
+
|
| 132 |
+
# ββ 3. Classify ββ
|
| 133 |
+
logit = self.classifier(pooled.float()) # (B, 1)
|
| 134 |
+
logit = torch.nan_to_num(logit, nan=0.0, posinf=20.0, neginf=-20.0)
|
| 135 |
+
|
| 136 |
+
result = {"logit": logit}
|
| 137 |
+
|
| 138 |
+
# ββ 4. Loss ββ
|
| 139 |
+
if labels is not None:
|
| 140 |
+
labels = labels.to(logit.dtype).to(logit.device)
|
| 141 |
+
if labels.dim() == 1:
|
| 142 |
+
labels = labels.unsqueeze(1) # (B,) β (B, 1)
|
| 143 |
+
loss_fn = nn.BCEWithLogitsLoss()
|
| 144 |
+
result["loss"] = loss_fn(logit, labels)
|
| 145 |
+
|
| 146 |
+
return result
|
| 147 |
+
|
| 148 |
+
def prepare_inputs_for_generation(self, *args, **kwargs):
|
| 149 |
+
"""Stub required by PEFT β we never generate text."""
|
| 150 |
+
raise NotImplementedError("TMOSClassifier does not support generation.")
|
| 151 |
+
|
| 152 |
+
def gradient_checkpointing_enable(self, **kwargs):
|
| 153 |
+
"""Delegate to the base model for HF Trainer compatibility."""
|
| 154 |
+
self.base.model.gradient_checkpointing_enable(**kwargs)
|
| 155 |
+
|
| 156 |
+
@property
|
| 157 |
+
def config(self):
|
| 158 |
+
"""Expose the base model config for PEFT."""
|
| 159 |
+
return self.base.config
|
| 160 |
+
|
| 161 |
+
@property
|
| 162 |
+
def device(self):
|
| 163 |
+
return next(self.parameters()).device
|
| 164 |
+
|
| 165 |
+
@property
|
| 166 |
+
def dtype(self):
|
| 167 |
+
return next(self.parameters()).dtype
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# βββ Standalone Test ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 171 |
+
if __name__ == "__main__":
|
| 172 |
+
import os
|
| 173 |
+
from dotenv import load_dotenv
|
| 174 |
+
load_dotenv()
|
| 175 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 176 |
+
|
| 177 |
+
print("Testing TMOSClassifier...")
|
| 178 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 179 |
+
|
| 180 |
+
clf = TMOSClassifier(
|
| 181 |
+
base_model_id="llava-hf/llava-1.5-7b-hf",
|
| 182 |
+
torch_dtype=torch.float16,
|
| 183 |
+
token=HF_TOKEN,
|
| 184 |
+
)
|
| 185 |
+
clf.to(device)
|
| 186 |
+
|
| 187 |
+
# Print parameter counts
|
| 188 |
+
total = sum(p.numel() for p in clf.parameters())
|
| 189 |
+
trainable = sum(p.numel() for p in clf.parameters() if p.requires_grad)
|
| 190 |
+
print(f"Total params: {total:>12,}")
|
| 191 |
+
print(f"Trainable params: {trainable:>12,}")
|
| 192 |
+
print(f"Classifier head: {sum(p.numel() for p in clf.classifier.parameters()):,}")
|
| 193 |
+
|
| 194 |
+
# Smoke test with dummy input
|
| 195 |
+
from transformers import AutoProcessor
|
| 196 |
+
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", token=HF_TOKEN)
|
| 197 |
+
processor.patch_size = 14
|
| 198 |
+
processor.vision_feature_select_strategy = "default"
|
| 199 |
+
|
| 200 |
+
from PIL import Image
|
| 201 |
+
dummy_img = Image.new("RGB", (336, 336), color=(128, 128, 128))
|
| 202 |
+
inputs = processor(
|
| 203 |
+
text="USER: <image>\nIs this real?\nASSISTANT:",
|
| 204 |
+
images=dummy_img,
|
| 205 |
+
return_tensors="pt",
|
| 206 |
+
).to(device)
|
| 207 |
+
|
| 208 |
+
labels = torch.tensor([1.0], device=device) # fake
|
| 209 |
+
|
| 210 |
+
with torch.no_grad():
|
| 211 |
+
out = clf(**inputs, labels=labels)
|
| 212 |
+
|
| 213 |
+
print(f"Logit: {out['logit'].item():.4f}")
|
| 214 |
+
print(f"Loss: {out['loss'].item():.4f}")
|
| 215 |
+
print(f"Prob: {torch.sigmoid(out['logit']).item():.4f}")
|
| 216 |
+
print("Test passed.")
|