screen-on-off-classifier / screen_classifier.py
dhruvguptaa's picture
Upload trained screen ON/OFF classifier
73eaad9 verified
#!/usr/bin/env python3
"""
Complete Screen ON/OFF Training Pipeline
=========================================
1. Generates synthetic dataset (2000 images/class)
2. Trains the lightweight CNN with early stopping
3. Evaluates with detailed metrics
4. Saves model (.pth + TorchScript .pt)
5. Pushes model to Hugging Face Hub
6. Uploads dataset to HF Hub
Can be run standalone or via hf_jobs.
"""
import os
import sys
import time
import copy
import random
import math
import json
import argparse
from pathlib import Path
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from PIL import Image, ImageDraw, ImageFilter
# ──────────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────────
# Every knob can be overridden from the environment (handy for hf_jobs runs);
# the second argument of each lookup is the default used when unset.
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "dhruvguptaa/screen-on-off-classifier")
HUB_DATASET_ID = os.environ.get("HUB_DATASET_ID", "dhruvguptaa/screen-on-off-dataset")
N_PER_CLASS = int(os.environ.get("N_PER_CLASS", "2000"))  # synthetic images per class
EPOCHS = int(os.environ.get("EPOCHS", "80"))  # upper bound; early stopping may end sooner
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))
LR = float(os.environ.get("LR", "1e-3"))  # initial Adam learning rate (cosine-annealed)
PATIENCE = int(os.environ.get("PATIENCE", "10"))  # epochs without val-loss improvement before stopping
SEED = int(os.environ.get("SEED", "42"))
# Output locations default to /tmp but, like the settings above, can be
# redirected via the environment (e.g. when /tmp is small or ephemeral).
DATA_DIR = os.environ.get("DATA_DIR", "/tmp/screen_data")
SAVE_DIR = os.environ.get("SAVE_DIR", "/tmp/model_output")
# ──────────────────────────────────────────────────────────────────
# SYNTHETIC DATA GENERATION
# ──────────────────────────────────────────────────────────────────
def gaussian_blob(h, w, cx, cy, sigma, intensity):
    """Return an (h, w) float32 map holding one isotropic Gaussian bump.

    The peak value ``intensity`` sits at column ``cx`` / row ``cy`` and falls
    off with standard deviation ``sigma`` (in pixels).
    """
    rows, cols = np.ogrid[:h, :w]
    sq_dist = (cols - cx) ** 2 + (rows - cy) ** 2
    bump = np.exp(-sq_dist / (2 * sigma ** 2))
    return (bump * intensity).astype(np.float32)
def random_color(min_v=0, max_v=255):
    """Return an (r, g, b) tuple, each channel drawn uniformly from [min_v, max_v]."""
    channels = []
    for _ in range(3):
        channels.append(random.randint(min_v, max_v))
    return tuple(channels)
def random_bright_color():
    """Return one colour from a fixed palette of saturated UI accents and near-whites."""
    palette = (
        (66, 133, 244), (234, 67, 53), (251, 188, 4), (52, 168, 83),
        (255, 255, 255), (245, 245, 245), (33, 150, 243), (76, 175, 80),
        (255, 152, 0), (156, 39, 176), (0, 188, 212), (255, 87, 34),
        (63, 81, 181), (0, 150, 136), (121, 85, 72),
    )
    # randrange(n) draws from the RNG exactly like random.choice does.
    pick = random.randrange(len(palette))
    return palette[pick]
def add_noise(img_array, strength=5):
    """Add zero-mean Gaussian noise (std ``strength``) and return a uint8 array clipped to [0, 255]."""
    jitter = np.random.normal(0, strength, img_array.shape).astype(np.float32)
    noisy = img_array.astype(np.float32) + jitter
    clipped = np.clip(noisy, 0, 255)
    return clipped.astype(np.uint8)
def generate_on_screen(size=(128, 128)):
    """Generate a realistic 'screen ON' image with varied UI layouts.

    One of ten UI styles is chosen at random:
    - light/dark app
    - media player
    - gradient home screen with icons
    - notification shade
    - chat
    - settings
    - browser
    - blurred colourful app
    - lit lock screen

    The chosen style is drawn with PIL and then post-processed:
    - images with mean < 80 are rescaled to a mean of ~100;
    - a vignette and a faint glass reflection are each added with some probability;
    - sensor noise is always added.

    Args:
        size: ``(width, height)`` of the output image in pixels.

    Returns:
        An RGB ``PIL.Image.Image`` of the requested size.
    """
    w, h = size
    img = Image.new("RGB", (w, h))
    draw = ImageDraw.Draw(img)
    style = random.choice([
        "light_app", "dark_app", "media", "gradient_bg", "notification",
        "chat", "settings", "browser", "colorful_app", "lock_screen_on",
    ])
    if style == "light_app":
        bg = random.choice([(255, 255, 255), (245, 245, 245), (250, 250, 250)])
        draw.rectangle([0, 0, w, h], fill=bg)
        bar_h = random.randint(8, 14)
        draw.rectangle([0, 0, w, bar_h], fill=random.choice([(50, 50, 50), (33, 33, 33), (66, 133, 244)]))
        # Status-bar glyphs on the left and right edges.
        for x_pos in [5, 12, 19]:
            draw.rectangle([x_pos, 2, x_pos + 4, bar_h - 2], fill=(200, 200, 200))
        for x_pos in [w - 25, w - 18, w - 10]:
            draw.rectangle([x_pos, 2, x_pos + 5, bar_h - 2], fill=(200, 200, 200))
        title_y = bar_h + 2
        draw.rectangle([0, title_y, w, title_y + 12], fill=random_bright_color())
        # Grey "text" lines of random length down the page.
        y = title_y + 18
        while y < h - 20:
            lw = random.randint(w // 3, w - 10)
            draw.rectangle([8, y, 8 + lw, y + random.randint(2, 4)], fill=(random.randint(60, 160),) * 3)
            y += random.randint(6, 14)
        if random.random() > 0.4:
            btn_y = random.randint(h // 2, h - 20)
            draw.rounded_rectangle([w // 4, btn_y, 3 * w // 4, btn_y + 12], radius=4, fill=random_bright_color())
    elif style == "dark_app":
        bg = random.choice([(18, 18, 18), (30, 30, 30), (20, 20, 30), (25, 25, 35)])
        draw.rectangle([0, 0, w, h], fill=bg)
        bar_h = random.randint(8, 14)
        draw.rectangle([0, 0, w, bar_h], fill=(35, 35, 35))
        accent = random_bright_color()
        # Bottom navigation bar with one highlighted tab.
        nav_h = random.randint(10, 16)
        draw.rectangle([0, h - nav_h, w, h], fill=(30, 30, 30))
        n_tabs = random.randint(3, 5)
        tab_w = w // n_tabs
        active = random.randint(0, n_tabs - 1)
        for i in range(n_tabs):
            x = i * tab_w + tab_w // 2 - 3
            draw.rectangle([x, h - nav_h + 3, x + 6, h - 3], fill=accent if i == active else (100, 100, 100))
        # Stacked content cards filled with text lines.
        y = bar_h + 8
        while y < h - nav_h - 10:
            card_h = random.randint(15, 30)
            draw.rounded_rectangle([6, y, w - 6, y + card_h], radius=3, fill=(45, 45, 50))
            for ly in range(y + 4, min(y + card_h - 4, h), 5):
                draw.rectangle([10, ly, 10 + random.randint(20, w - 20), ly + 2], fill=(160, 160, 160))
            y += card_h + 4
    elif style == "media":
        arr = np.zeros((h, w, 3), dtype=np.uint8)
        c1 = np.array(random_bright_color(), dtype=np.float32)
        c2 = np.array(random_bright_color(), dtype=np.float32)
        for row in range(h):
            arr[row] = (c1 * (1 - row / h) + c2 * (row / h)).astype(np.uint8)
        img = Image.fromarray(arr)
        draw = ImageDraw.Draw(img)
        if random.random() > 0.3:
            # Centred "play" triangle.
            cx, cy = w // 2, h // 2
            s = random.randint(10, 20)
            draw.polygon([(cx - s // 2, cy - s), (cx - s // 2, cy + s), (cx + s, cy)], fill=(255, 255, 255))
        # Progress bar: grey track with a red filled fraction p.
        bar_y = h - random.randint(15, 25)
        p = random.uniform(0.1, 0.9)
        draw.rectangle([5, bar_y, w - 5, bar_y + 3], fill=(100, 100, 100))
        draw.rectangle([5, bar_y, 5 + int((w - 10) * p), bar_y + 3], fill=(255, 0, 0))
    elif style == "gradient_bg":
        c1 = np.array(random_bright_color(), dtype=np.float32)
        c2 = np.array(random_bright_color(), dtype=np.float32)
        # Diagonal blend with weight t = (row/h + col/w) / 2, computed for all
        # pixels at once instead of a per-pixel Python loop. The weights are
        # cast to float32 before multiplying, so the result matches the
        # float32 arithmetic of the former scalar loop.
        t = (np.arange(h)[:, np.newaxis] / h + np.arange(w)[np.newaxis, :] / w) / 2
        arr = (c1 * (1 - t).astype(np.float32)[:, :, np.newaxis]
               + c2 * t.astype(np.float32)[:, :, np.newaxis])
        img = Image.fromarray(arr.astype(np.uint8))
        draw = ImageDraw.Draw(img)
        # Grid of app icons; rows that would run into the bottom margin are skipped.
        icon_size = random.randint(10, 16)
        cols_n = random.randint(3, 5)
        rows_n = random.randint(3, 5)
        x_start = (w - cols_n * (icon_size + 6)) // 2
        y_start = random.randint(20, 35)
        for r in range(rows_n):
            for c in range(cols_n):
                x = x_start + c * (icon_size + 6)
                y = y_start + r * (icon_size + 10)
                if y + icon_size < h - 15:
                    draw.rounded_rectangle([x, y, x + icon_size, y + icon_size], radius=3, fill=random_bright_color())
    elif style == "notification":
        draw.rectangle([0, 0, w, h], fill=random.choice([(30, 30, 40), (20, 20, 25), (40, 40, 50)]))
        draw.rectangle([w // 4, 10, 3 * w // 4, 28], fill=(220, 220, 220))
        y = 35
        for _ in range(random.randint(2, 5)):
            if y > h - 20:
                break
            ch = random.randint(18, 30)
            draw.rounded_rectangle([8, y, w - 8, y + ch], radius=4, fill=(50, 50, 60))
            draw.ellipse([12, y + 4, 22, y + 14], fill=random_bright_color())
            draw.rectangle([26, y + 5, w - 15, y + 8], fill=(200, 200, 200))
            draw.rectangle([26, y + 11, random.randint(w // 2, w - 15), y + 14], fill=(150, 150, 150))
            y += ch + 5
    elif style == "chat":
        is_dark = random.random() > 0.5
        bg = (18, 18, 22) if is_dark else random.choice([(230, 230, 235), (255, 255, 255)])
        draw.rectangle([0, 0, w, h], fill=bg)
        draw.rectangle([0, 0, w, 16], fill=random_bright_color())
        y = 22
        while y < h - 25:
            # Sent bubbles are right-aligned and blue; received ones left-aligned.
            is_sent = random.random() > 0.5
            bw = random.randint(w // 3, 2 * w // 3)
            bh = random.randint(10, 22)
            x1 = w - bw - 8 if is_sent else 8
            color = (0, 132, 255) if is_sent else ((50, 50, 55) if is_dark else (229, 229, 234))
            draw.rounded_rectangle([x1, y, x1 + bw, y + bh], radius=5, fill=color)
            text_c = (255, 255, 255) if is_sent or is_dark else (30, 30, 30)
            for ly in range(y + 3, min(y + bh - 3, h), 5):
                draw.rectangle([x1 + 5, ly, x1 + 5 + random.randint(bw // 3, bw - 8), ly + 2], fill=text_c)
            y += bh + random.randint(4, 10)
        # Message-input bar along the bottom edge.
        draw.rectangle([0, h - 14, w, h], fill=(35, 35, 40) if is_dark else (240, 240, 240))
    elif style == "settings":
        is_dark = random.random() > 0.5
        bg = (0, 0, 0) if is_dark else (242, 242, 247)
        draw.rectangle([0, 0, w, h], fill=bg)
        bar_h = 16
        draw.rectangle([w // 4, 4, 3 * w // 4, 12], fill=(255, 255, 255) if is_dark else (0, 0, 0))
        y = bar_h + 4
        row_bg = (28, 28, 30) if is_dark else (255, 255, 255)
        while y < h - 8:
            rh = random.randint(12, 18)
            draw.rectangle([0, y, w, y + rh], fill=row_bg)
            draw.rounded_rectangle([8, y + 3, 18, y + rh - 3], radius=2, fill=random_bright_color())
            label_c = (220, 220, 220) if is_dark else (30, 30, 30)
            draw.rectangle([24, y + rh // 2 - 1, 24 + random.randint(30, w - 40), y + rh // 2 + 1], fill=label_c)
            y += rh
    elif style == "browser":
        draw.rectangle([0, 0, w, h], fill=(255, 255, 255))
        bar_h = 14
        draw.rectangle([0, 0, w, bar_h], fill=(245, 245, 245))
        draw.rounded_rectangle([8, 2, w - 8, bar_h - 2], radius=4, fill=(255, 255, 255), outline=(200, 200, 200))
        draw.rectangle([14, 5, w // 2, 9], fill=(100, 100, 100))
        y = bar_h + 6
        while y < h - 10:
            elem = random.choice(["text", "image", "heading"])
            if elem == "text":
                for _ in range(random.randint(2, 5)):
                    if y > h - 8:
                        break
                    draw.rectangle([8, y, 8 + random.randint(w // 2, w - 12), y + 2], fill=(50, 50, 50))
                    y += 5
            elif elem == "image":
                ih = random.randint(20, 40)
                draw.rectangle([8, y, w - 8, y + ih], fill=random_color(120, 240))
                y += ih + 4
            else:
                draw.rectangle([8, y, w // 2 + 20, y + 5], fill=(20, 20, 20))
                y += 10
            y += random.randint(4, 10)
    elif style == "colorful_app":
        arr = np.random.randint(60, 255, (h, w, 3), dtype=np.uint8)
        img = Image.fromarray(arr).filter(ImageFilter.GaussianBlur(radius=8))
        draw = ImageDraw.Draw(img)
        draw.rectangle([0, 0, w, 12], fill=(0, 0, 0))
        for x in [6, 14, 22]:
            draw.rectangle([x, 3, x + 5, 9], fill=(255, 255, 255))
    else:  # lock_screen_on
        c1 = np.array(random_bright_color(), dtype=np.float32)
        c2 = np.array(random_bright_color(), dtype=np.float32) * 0.4
        arr = np.zeros((h, w, 3), dtype=np.float32)
        for row in range(h):
            arr[row] = c1 * (1 - row / h) + c2 * (row / h)
        img = Image.fromarray(arr.astype(np.uint8)).filter(ImageFilter.GaussianBlur(radius=3))
        draw = ImageDraw.Draw(img)
        # Clock block and date line.
        draw.rectangle([w // 4, h // 4, 3 * w // 4, h // 4 + 20], fill=(255, 255, 255))
        draw.rectangle([w // 3, h // 4 + 25, 2 * w // 3, h // 4 + 30], fill=(220, 220, 220))

    # Post-processing: keep ON images clearly brighter than OFF ones by
    # rescaling dark results to a mean of ~100.
    arr = np.array(img, dtype=np.float32)
    if arr.mean() < 80:
        arr = arr * (100 / max(arr.mean(), 1))
        arr = np.clip(arr, 0, 255)
    # Vignette: darken towards the corners by up to `vs`.
    if random.random() > 0.3:
        vs = random.uniform(0.02, 0.08)
        y_v, x_v = np.ogrid[:h, :w]
        r = np.sqrt((x_v - w / 2) ** 2 + (y_v - h / 2) ** 2)
        r_max = np.sqrt((w / 2) ** 2 + (h / 2) ** 2)
        arr = arr * (1 - vs * (r / r_max) ** 2)[:, :, np.newaxis]
    # Glass reflection: one faint, wide bright blob.
    if random.random() > 0.5:
        rs = random.uniform(0.01, 0.04)
        blob = gaussian_blob(h, w, random.randint(0, w), random.randint(0, h), random.uniform(30, 60), 255 * rs)
        arr += blob[:, :, np.newaxis]
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    arr = add_noise(arr, random.uniform(2, 6))
    return Image.fromarray(arr)
def generate_off_screen(size=(128, 128)):
    """Generate a realistic 'screen OFF' image with dark glass effects.

    One of nine styles sets a dark base:
    - near-black
    - blue or warm ambient tint
    - a sunset-like fade from a warm tint at the top to black at the bottom
    - a generic dark panel

    Each of the following layers is then added when the style calls for it,
    or otherwise with some probability:
    - specular glare blobs
    - blurred scene reflections
    - fingerprint smudges
    - a directional ambient gradient

    If the result's mean exceeds 70, it is rescaled to a mean of 60. Sensor
    noise is added last.

    Args:
        size: ``(width, height)`` of the output image in pixels.

    Returns:
        An RGB ``PIL.Image.Image`` of the requested size.
    """
    w, h = size
    style = random.choice([
        "dark_clean", "dark_glare", "reflection_heavy", "fingerprints",
        "ambient_bright", "sunset_reflection", "indoor_reflection",
        "near_black", "blue_ambient",
    ])
    if style == "near_black":
        arr = np.random.uniform(0, 8, (h, w, 3)).astype(np.float32)
    elif style == "blue_ambient":
        arr = np.full((h, w, 3), [random.uniform(5, 20), random.uniform(8, 25), random.uniform(15, 40)], dtype=np.float32)
        arr += np.random.uniform(-3, 3, (h, w, 3))
    elif style == "ambient_bright":
        base = random.uniform(15, 45)
        arr = np.full((h, w, 3), base, dtype=np.float32)
        # Warm (red-shifted) or cool (blue-shifted) cast.
        if random.random() > 0.5:
            arr[:, :, 0] += random.uniform(5, 15)
            arr[:, :, 2] -= random.uniform(0, 5)
        else:
            arr[:, :, 2] += random.uniform(5, 15)
            arr[:, :, 0] -= random.uniform(0, 5)
    elif style == "sunset_reflection":
        # One warm tint per image, faded linearly to black towards the bottom.
        # The amplitudes used to be re-drawn for every row, which turned the
        # intended smooth fade into horizontal scan-line noise.
        tint = np.array([random.uniform(20, 60), random.uniform(10, 30), random.uniform(5, 20)], dtype=np.float32)
        fade = 1 - np.arange(h, dtype=np.float32) / h
        arr = np.repeat((fade[:, np.newaxis] * tint)[:, np.newaxis, :], w, axis=1)
    else:
        base = random.uniform(3, 25)
        arr = np.full((h, w, 3), base, dtype=np.float32)
        arr[:, :, 0] += random.uniform(-3, 10)
        arr[:, :, 1] += random.uniform(-3, 10)
        arr[:, :, 2] += random.uniform(-3, 15)

    # Pixel-coordinate grids shared by the fingerprint and corner-gradient passes
    # (previously rebuilt on every loop iteration).
    y_grid, x_grid = np.ogrid[:h, :w]

    # Glare: 1-3 near-white gaussian highlights.
    if style in ["dark_glare", "ambient_bright", "indoor_reflection"] or random.random() > 0.4:
        for _ in range(random.randint(1, 3)):
            blob = gaussian_blob(h, w, random.randint(0, w), random.randint(0, h),
                                 random.uniform(10, 50), random.uniform(30, 180))
            glare_c = np.array([random.uniform(0.8, 1), random.uniform(0.8, 1), random.uniform(0.8, 1)])
            arr += blob[:, :, np.newaxis] * glare_c
    # Scene reflection: a few coloured rectangles, heavily blurred and faintly blended in.
    if style in ["reflection_heavy", "indoor_reflection", "sunset_reflection"] or random.random() > 0.5:
        refl = np.zeros((h, w, 3), dtype=np.float32)
        for _ in range(random.randint(2, 6)):
            rx1, ry1 = random.randint(0, w - 10), random.randint(0, h - 10)
            rx2 = min(w, rx1 + random.randint(10, 50))
            ry2 = min(h, ry1 + random.randint(10, 50))
            refl[ry1:ry2, rx1:rx2] = np.array(random_color(20, 120), dtype=np.float32)
        refl_img = Image.fromarray(refl.astype(np.uint8)).filter(
            ImageFilter.GaussianBlur(radius=random.uniform(8, 20)))
        arr += np.array(refl_img, dtype=np.float32) * random.uniform(0.04, 0.15)
    # Fingerprints: rotated elliptical gaussian smudges.
    if style == "fingerprints" or random.random() > 0.5:
        for _ in range(random.randint(1, 4)):
            cx, cy = random.randint(10, w - 10), random.randint(10, h - 10)
            angle = random.uniform(0, math.pi)
            sx, sy = random.uniform(8, 25), random.uniform(5, 15)
            dx = (x_grid - cx) * math.cos(angle) + (y_grid - cy) * math.sin(angle)
            dy = -(x_grid - cx) * math.sin(angle) + (y_grid - cy) * math.cos(angle)
            arr += (np.exp(-((dx / sx) ** 2 + (dy / sy) ** 2) * 2) * random.uniform(3, 12))[:, :, np.newaxis]
    # Ambient gradient from one edge or corner.
    if random.random() > 0.3:
        direction = random.choice(["top", "bottom", "left", "right", "corner"])
        if direction == "top":
            grad = np.linspace(1, 0, h)[:, np.newaxis] * np.ones(w)
        elif direction == "bottom":
            grad = np.linspace(0, 1, h)[:, np.newaxis] * np.ones(w)
        elif direction == "left":
            grad = np.ones((h, 1)) * np.linspace(1, 0, w)
        elif direction == "right":
            grad = np.ones((h, 1)) * np.linspace(0, 1, w)
        else:
            grad = np.sqrt((x_grid / w) ** 2 + (y_grid / h) ** 2) / np.sqrt(2)
        arr += grad[:, :, np.newaxis] * random.uniform(3, 20)
    # Cap brightness so OFF images stay clearly darker than ON ones.
    arr = np.clip(arr, 0, 255)
    if arr.mean() > 70:
        arr = arr * (60 / arr.mean())
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    arr = add_noise(arr, random.uniform(2, 8))
    return Image.fromarray(arr)
def generate_dataset(output_dir, n_per_class=2000, size=(128, 128), seed=42):
    """Render ``n_per_class`` ON and OFF images into ``output_dir/on`` and ``output_dir/off``.

    Both ``random`` and ``numpy.random`` are seeded first, so a given seed
    reproduces the same dataset. All ON images are generated before any OFF
    image and are saved as ``on_00000.png``, ``off_00000.png``, and so on.

    Returns:
        ``output_dir``, unchanged.
    """
    random.seed(seed)
    np.random.seed(seed)
    root = Path(output_dir)
    targets = {"on": root / "on", "off": root / "off"}
    for folder in targets.values():
        folder.mkdir(parents=True, exist_ok=True)

    plan = (("ON", "on", generate_on_screen), ("OFF", "off", generate_off_screen))
    for tag, prefix, render in plan:
        print(f"Generating {n_per_class} {tag} images...")
        for idx in range(n_per_class):
            render(size).save(targets[prefix] / f"{prefix}_{idx:05d}.png")
            done = idx + 1
            if done % 500 == 0:
                print(f" {tag}: {done}/{n_per_class}")
    print(f"Dataset: {n_per_class*2} images in {output_dir}")
    return output_dir
# ──────────────────────────────────────────────────────────────────
# MODEL
# ──────────────────────────────────────────────────────────────────
class ScreenClassifier(nn.Module):
    """
    Tiny 3-layer CNN for binary screen ON/OFF classification.
    Input: [B, 1, 64, 64] grayscale | Output: [B, 1] logit

    Three conv -> BatchNorm -> ReLU stages widen the channels 1 -> 16 -> 32 -> 64.
    The first two stages end in 2x2 max-pooling and the last in global average
    pooling, so the head is a dropout-regularised 64 -> 1 linear layer.
    """

    def __init__(self):
        super().__init__()
        widths = (1, 16, 32, 64)
        stages = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            is_last = idx == len(widths) - 2
            stages += [
                nn.Conv2d(c_in, c_out, 3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(1) if is_last else nn.MaxPool2d(2),
            ]
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(nn.Flatten(), nn.Dropout(0.3), nn.Linear(widths[-1], 1))
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights for conv/linear layers; identity-initialised BatchNorm."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)
                continue
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                fan_mode = "fan_out" if isinstance(module, nn.Conv2d) else "fan_in"
                nn.init.kaiming_normal_(module.weight, mode=fan_mode, nonlinearity="relu")
                if isinstance(module, nn.Linear) and module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        pooled = self.features(x)
        return self.classifier(pooled)
# ──────────────────────────────────────────────────────────────────
# EARLY STOPPING
# ──────────────────────────────────────────────────────────────────
class EarlyStopping:
    """Track validation loss, snapshot the best weights, and signal when to stop.

    An epoch counts as an improvement only if its loss beats the best so far by
    more than ``min_delta``. Every other epoch increments a counter. Once the
    counter reaches ``patience``, ``should_stop`` is set and never reset.
    """

    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_state = None  # deep copy of model.state_dict() at the best epoch
        self.counter = 0  # epochs since the last improvement
        self.should_stop = False

    def step(self, val_loss, model):
        """Record ``val_loss`` for ``model``; return True once patience is exhausted."""
        improved = val_loss < self.best_loss - self.min_delta
        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
            return self.should_stop
        self.best_loss = val_loss
        self.best_state = copy.deepcopy(model.state_dict())
        self.counter = 0
        return self.should_stop

    def restore_best(self, model):
        """Load the best snapshot into ``model``; no-op if nothing was recorded."""
        if not self.best_state:
            return
        model.load_state_dict(self.best_state)
# ──────────────────────────────────────────────────────────────────
# DATA LOADING
# ──────────────────────────────────────────────────────────────────
class _TransformSubset(torch.utils.data.Dataset):
def __init__(self, dataset, indices, transform):
self.dataset = dataset
self.indices = list(indices)
self.transform = transform
def __len__(self):
return len(self.indices)
def __getitem__(self, idx):
img_path, label = self.dataset.samples[self.indices[idx]]
img = Image.open(img_path).convert("RGB")
if self.transform:
img = self.transform(img)
return img, label
def get_transforms(augment=True):
    """Return the torchvision preprocessing pipeline for 64x64 grayscale input.

    Both pipelines convert to 1-channel grayscale, produce a tensor and
    normalise with mean 0.5 / std 0.5.

    With ``augment=True``, the image is:
    - resized to 72x72;
    - randomly cropped back to 64x64;
    - randomly flipped, rotated (up to 10 degrees) and shifted/scaled;
    - brightness/contrast jittered.

    Otherwise it is simply resized to 64x64.
    """
    steps = [transforms.Grayscale(1)]
    if not augment:
        steps.append(transforms.Resize((64, 64)))
    else:
        steps.extend([
            transforms.Resize((72, 72)),
            transforms.RandomCrop((64, 64)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomRotation(10),
            transforms.RandomAffine(0, translate=(0.08, 0.08), scale=(0.9, 1.1)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3),
        ])
    steps += [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    return transforms.Compose(steps)
def build_dataloaders(data_dir, batch_size=32, val_split=0.2, seed=42):
    """Build train/validation DataLoaders from an ImageFolder directory.

    ``data_dir`` holds one sub-directory per class; labels come from
    ``ImageFolder.class_to_idx``. The file list is split once, with a
    generator seeded by ``seed``, into ``val_split`` validation and the
    remaining training samples. The training subset uses the augmenting
    transform and the validation subset the deterministic one.

    Returns:
        ``(train_loader, val_loader, class_to_idx)``.

    Raises:
        ValueError: if the split would leave the training set empty.
    """
    # The base dataset only supplies the (path, label) listing; each split
    # applies its own transform through _TransformSubset, so none is set here.
    full_ds = datasets.ImageFolder(root=data_dir)
    print(f"[DATA] Classes: {full_ds.class_to_idx}, Total: {len(full_ds)}")
    n_total = len(full_ds)
    n_val = max(1, int(n_total * val_split))
    n_train = n_total - n_val
    if n_train < 1:
        # An empty training set would otherwise surface later as a
        # ZeroDivisionError inside the training loop.
        raise ValueError(
            f"Need at least 2 images to build a train/val split, found {n_total} in {data_dir}"
        )
    gen = torch.Generator().manual_seed(seed)
    train_idx, val_idx = random_split(range(n_total), [n_train, n_val], generator=gen)
    train_ds = _TransformSubset(full_ds, train_idx.indices, get_transforms(True))
    val_ds = _TransformSubset(full_ds, val_idx.indices, get_transforms(False))
    print(f"[DATA] Train: {len(train_ds)}, Val: {len(val_ds)}")
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=False)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False)
    return train_loader, val_loader, full_ds.class_to_idx
# ──────────────────────────────────────────────────────────────────
# TRAINING
# ──────────────────────────────────────────────────────────────────
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``."""
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device).float()
        logits = model(imgs).squeeze(1)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        correct += ((torch.sigmoid(logits.detach()) >= 0.5).float() == labels).sum().item()
        total += imgs.size(0)
    return total_loss / total, correct / total


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without gradients.

    Predictions are sigmoid(logit) >= 0.5, and "positive" means label 1.

    Returns:
        A dict with the mean ``loss``, the accuracy ``acc``, and the confusion
        counts ``tp``/``tn``/``fp``/``fn``.
    """
    model.eval()
    total_loss, total = 0.0, 0
    tp = tn = fp = fn = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device).float()
            logits = model(imgs).squeeze(1)
            total_loss += criterion(logits, labels).item() * imgs.size(0)
            preds = (torch.sigmoid(logits) >= 0.5).float()
            tp += ((preds == 1) & (labels == 1)).sum().item()
            tn += ((preds == 0) & (labels == 0)).sum().item()
            fp += ((preds == 1) & (labels == 0)).sum().item()
            fn += ((preds == 0) & (labels == 1)).sum().item()
            total += imgs.size(0)
    return {"loss": total_loss / total, "acc": (tp + tn) / total,
            "tp": tp, "tn": tn, "fp": fp, "fn": fn}


def _benchmark_ms(model, runs=200):
    """Return the mean forward latency, in ms, for one 1x1x64x64 input on CPU."""
    dummy = torch.randn(1, 1, 64, 64)
    times = []
    with torch.no_grad():
        for _ in range(runs):
            t0 = time.perf_counter()
            model(dummy)
            times.append(time.perf_counter() - t0)
    return float(np.mean(times) * 1000)


def train_model(data_dir, epochs, batch_size, lr, patience, seed, save_dir):
    """Train ScreenClassifier on ``data_dir``, evaluate it, and save artifacts.

    Training uses Adam with cosine-annealed learning rate and BCE-with-logits
    loss. It stops early after ``patience`` epochs without validation-loss
    improvement. The weights with the lowest validation loss are restored
    before evaluation and saving, whether or not early stopping fired.

    Files written to ``save_dir``:
    - ``screen_classifier_best.pth`` (state dict)
    - ``screen_classifier_best.pt`` (TorchScript, best-effort)
    - ``metrics.json``

    Returns:
        ``(model, metrics)`` -- the model holding the restored best weights,
        and the metrics dict that was also written to ``metrics.json``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cpu")
    print(f"[TRAIN] Device: {device}")

    train_loader, val_loader, class_to_idx = build_dataloaders(data_dir, batch_size, seed=seed)
    model = ScreenClassifier().to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[MODEL] Parameters: {n_params:,}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    early_stop = EarlyStopping(patience=patience)
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_acc = 0.0
    epochs_trained = 0  # stays 0 (instead of an unbound name) when epochs == 0

    for epoch in range(1, epochs + 1):
        epochs_trained = epoch
        t_loss, t_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_stats = _evaluate(model, val_loader, criterion, device)
        v_loss, v_acc = val_stats["loss"], val_stats["acc"]
        scheduler.step()
        lr_now = optimizer.param_groups[0]["lr"]
        history["train_loss"].append(t_loss)
        history["val_loss"].append(v_loss)
        history["train_acc"].append(t_acc)
        history["val_acc"].append(v_acc)
        best_val_acc = max(best_val_acc, v_acc)
        # Step first so the printed ES counter reflects this epoch.
        should_stop = early_stop.step(v_loss, model)
        print(
            f"Epoch {epoch:3d}/{epochs} | "
            f"Train Loss: {t_loss:.4f} Acc: {t_acc:.4f} | "
            f"Val Loss: {v_loss:.4f} Acc: {v_acc:.4f} | "
            f"LR: {lr_now:.6f} | ES: {early_stop.counter}/{early_stop.patience}"
        )
        if should_stop:
            print(f"\n[EARLY STOP] Epoch {epoch}. Restoring best weights.")
            break

    # Always evaluate and save the best-validation-loss snapshot. Previously a
    # run that used all `epochs` wrote its *last*-epoch weights to "*_best".
    early_stop.restore_best(model)

    # ── Final evaluation ──
    # NOTE: this is the same split that drove early stopping, so the numbers
    # are somewhat optimistic; a held-out test split would be unbiased.
    final = _evaluate(model, val_loader, criterion, device)
    tp, tn, fp, fn = final["tp"], final["tn"], final["fp"], final["fn"]
    final_acc = final["acc"]
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    print(f"\n{'='*60}")
    print("FINAL EVALUATION")
    print(f"{'='*60}")
    print(f"Accuracy: {final_acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"TP={tp}, TN={tn}, FP={fp}, FN={fn}")
    print(f"{'='*60}")

    # ── Save ──
    pth_path = os.path.join(save_dir, "screen_classifier_best.pth")
    torch.save(model.state_dict(), pth_path)
    print(f"[SAVE] State dict: {pth_path}")
    pt_path = os.path.join(save_dir, "screen_classifier_best.pt")
    try:
        torch.jit.script(model).save(pt_path)
        print(f"[SAVE] TorchScript: {pt_path}")
    except Exception as e:  # TorchScript export is optional; the .pth is the primary artifact
        print(f"[WARN] TorchScript failed: {e}")

    # ── Benchmark ──
    avg_ms = _benchmark_ms(model, runs=200)
    print(f"[BENCH] Inference: {avg_ms:.2f} ms avg (200 runs)")

    metrics = {
        "accuracy": final_acc, "precision": precision, "recall": recall, "f1": f1,
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "best_val_acc": best_val_acc, "n_params": n_params,
        "epochs_trained": epochs_trained, "class_to_idx": class_to_idx,
        "inference_ms": avg_ms, "history": history,
    }
    with open(os.path.join(save_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    return model, metrics
# ──────────────────────────────────────────────────────────────────
# HUB UPLOAD
# ──────────────────────────────────────────────────────────────────
def _create_repo_best_effort(api, repo_id, token, repo_type=None):
    """Create ``repo_id`` on the Hub if missing, reporting (not hiding) any failure.

    The pipeline continues after a failure, as it did before this helper
    existed. If the repo truly cannot be used, the upload that follows is
    expected to raise its own error.
    """
    try:
        api.create_repo(repo_id, repo_type=repo_type, exist_ok=True, token=token)
    except Exception as e:
        print(f"[WARN] create_repo({repo_id!r}) failed, continuing: {e}")


def _build_model_card(metrics, n_per_class):
    """Return the model-repo README.md (YAML front matter + markdown).

    ``metrics`` is the dict loaded from ``metrics.json``. If it contains an
    ``inference_ms`` value, the measured latency is shown in the table instead
    of the generic "<1ms" claim.
    """
    if "inference_ms" in metrics:
        latency = f"{metrics['inference_ms']:.2f} ms avg (CPU, batch of 1)"
    else:
        latency = "<1ms (CPU)"
    return f"""---
tags:
- pytorch
- image-classification
- binary-classification
- screen-detection
- lightweight
- cpu-optimized
license: mit
metrics:
- accuracy
- f1
---

# Screen ON/OFF Classifier

Ultra-lightweight CNN (~23K params) that classifies phone screen images as **ON** or **OFF**.
Designed for real-time CPU inference (<1ms per frame).

## Performance

| Metric | Value |
|--------|-------|
| Accuracy | {metrics['accuracy']:.4f} |
| Precision | {metrics['precision']:.4f} |
| Recall | {metrics['recall']:.4f} |
| F1 Score | {metrics['f1']:.4f} |
| Parameters | {metrics['n_params']:,} |
| Inference | {latency} |

## Usage

```python
import numpy as np
import cv2
import torch
import torch.nn as nn


class ScreenClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1, bias=False), nn.BatchNorm2d(16), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True), nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(nn.Flatten(), nn.Dropout(0.3), nn.Linear(64, 1))

    def forward(self, x):
        return self.classifier(self.features(x))


# Load
model = ScreenClassifier()
model.load_state_dict(torch.load("screen_classifier_best.pth", map_location="cpu", weights_only=True))
model.eval()

# Predict from OpenCV frame
frame = cv2.imread("phone_screen.jpg")
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
resized = cv2.resize(gray, (64, 64), interpolation=cv2.INTER_AREA)
tensor = torch.from_numpy(resized.astype(np.float32)).div(255.0)
tensor = (tensor - 0.5) / 0.5
tensor = tensor.unsqueeze(0).unsqueeze(0)
with torch.no_grad():
    prob = torch.sigmoid(model(tensor).squeeze()).item()
label = "ON" if prob >= 0.5 else "OFF"
confidence = prob if label == "ON" else 1.0 - prob
print(f"{{label}} (confidence: {{confidence:.1%}})")
```

## Training

Trained on synthetic data ({n_per_class} images per class) with domain randomization.
- ON screens: 10 UI layout styles (light/dark apps, chat, media, settings, browser, etc.)
- OFF screens: 9 dark glass variations (glare, reflections, fingerprints, ambient lighting)

## Files

- `screen_classifier_best.pth` — PyTorch state dict
- `screen_classifier_best.pt` — TorchScript (deploy without class definition)
- `metrics.json` — Training metrics
- `screen_classifier.py` — Full training + inference code
"""


def _build_dataset_card(n_per_class):
    """Return the dataset-repo README.md describing the on/ and off/ image folders."""
    return f"""---
tags:
- image-classification
- binary-classification
- screen-detection
- synthetic
license: mit
task_categories:
- image-classification
---

# Screen ON/OFF Dataset (Synthetic)

Synthetic dataset for training phone screen ON/OFF binary classifiers.

## Structure

- `on/` — {n_per_class} images of screens in ON state (various UI layouts)
- `off/` — {n_per_class} images of screens in OFF state (dark glass with reflections, glare, fingerprints)

## Generation

Generated using domain randomization (arxiv:1703.06907 principles):
- **ON**: 10 UI styles (light/dark apps, chat, media, browser, settings, notifications, etc.)
- **OFF**: 9 glass surface styles (clean dark, glare, reflections, fingerprints, ambient lighting)

Images are 128x128 RGB PNG. Training resizes to 64x64 grayscale.
"""


def push_to_hub(save_dir, hub_model_id, hub_dataset_id, data_dir):
    """Push model + dataset to Hugging Face Hub.

    The upload happens in two steps:
    1. A README.md model card, built from ``save_dir/metrics.json``, is
       written into ``save_dir``, and the whole folder is uploaded to the
       model repo ``hub_model_id``.
    2. A dataset card is written into ``data_dir``, and that folder is
       uploaded to the dataset repo ``hub_dataset_id``.

    The ``HF_TOKEN`` environment variable is used for authentication when
    set. Upload errors propagate to the caller.
    """
    from huggingface_hub import HfApi, upload_folder

    api = HfApi()
    token = os.environ.get("HF_TOKEN")

    # ── Upload model ──
    print(f"\n[HUB] Pushing model to {hub_model_id}...")
    _create_repo_best_effort(api, hub_model_id, token)
    with open(os.path.join(save_dir, "metrics.json"), encoding="utf-8") as f:
        metrics = json.load(f)
    # Explicit UTF-8: the cards contain non-ASCII punctuation (em dashes).
    with open(os.path.join(save_dir, "README.md"), "w", encoding="utf-8") as f:
        f.write(_build_model_card(metrics, N_PER_CLASS))
    upload_folder(
        folder_path=save_dir,
        repo_id=hub_model_id,
        token=token,
        commit_message="Upload trained screen ON/OFF classifier",
    )
    print(f"[HUB] Model pushed: https://huggingface.co/{hub_model_id}")

    # ── Upload dataset ──
    print(f"[HUB] Pushing dataset to {hub_dataset_id}...")
    _create_repo_best_effort(api, hub_dataset_id, token, repo_type="dataset")
    with open(os.path.join(data_dir, "README.md"), "w", encoding="utf-8") as f:
        f.write(_build_dataset_card(N_PER_CLASS))
    upload_folder(
        folder_path=data_dir,
        repo_id=hub_dataset_id,
        repo_type="dataset",
        token=token,
        commit_message="Upload synthetic screen ON/OFF dataset",
    )
    print(f"[HUB] Dataset pushed: https://huggingface.co/datasets/{hub_dataset_id}")
# ──────────────────────────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────────────────────────
def main():
    """Run the full pipeline: generate data, train, then push to the Hub (best-effort).

    All settings come from the module-level constants. A failed Hub push is
    reported but does not abort the run, since the model files remain in
    SAVE_DIR.
    """
    import shutil

    print("=" * 60)
    print("SCREEN ON/OFF CLASSIFIER — FULL PIPELINE")
    print("=" * 60)

    # 1. Generate dataset
    print("\n[STEP 1] Generating synthetic dataset...")
    t0 = time.time()
    generate_dataset(DATA_DIR, N_PER_CLASS, (128, 128), SEED)
    print(f"Dataset generation: {time.time()-t0:.1f}s")

    # 2. Train model (artifacts are written to SAVE_DIR)
    print(f"\n[STEP 2] Training model (epochs={EPOCHS}, bs={BATCH_SIZE}, lr={LR})...")
    t0 = time.time()
    train_model(DATA_DIR, EPOCHS, BATCH_SIZE, LR, PATIENCE, SEED, SAVE_DIR)
    print(f"Training: {time.time()-t0:.1f}s")

    # Ship this script next to the weights so the Hub repo is self-contained.
    # `__file__` may be undefined when the code is exec'd by a job runner or
    # notebook; that must not discard a finished training run.
    try:
        shutil.copy2(os.path.abspath(__file__), os.path.join(SAVE_DIR, "screen_classifier.py"))
    except (NameError, OSError) as e:
        print(f"[WARN] Could not copy training script: {e}")

    # 3. Push to Hub
    print("\n[STEP 3] Pushing to Hugging Face Hub...")
    pushed = False
    try:
        push_to_hub(SAVE_DIR, HUB_MODEL_ID, HUB_DATASET_ID, DATA_DIR)
        pushed = True
    except Exception as e:  # best-effort: keep local results on network/auth failures
        print(f"[WARN] Hub push failed: {e}")
        print("Model files are still saved locally.")

    print("\n" + "=" * 60)
    print("DONE!")
    # Only advertise Hub URLs when the push actually succeeded.
    if pushed:
        print(f"Model: https://huggingface.co/{HUB_MODEL_ID}")
        print(f"Dataset: https://huggingface.co/datasets/{HUB_DATASET_ID}")
    else:
        print(f"Model files: {SAVE_DIR}")
        print(f"Dataset: {DATA_DIR}")
    print("=" * 60)
# Script entry point: run the whole pipeline only when executed directly,
# not when this module is imported (e.g. to reuse ScreenClassifier).
if __name__ == "__main__":
    main()