#!/usr/bin/env python3
"""
Complete Screen ON/OFF Training Pipeline
=========================================
1. Generates synthetic dataset (2000 images/class)
2. Trains the lightweight CNN with early stopping
3. Evaluates with detailed metrics
4. Saves model (.pth + TorchScript .pt)
5. Pushes model to Hugging Face Hub
6. Uploads dataset to HF Hub

Can be run standalone or via hf_jobs.
"""

import os
import sys
import time
import copy
import random
import math
import json
import shutil
import argparse
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from PIL import Image, ImageDraw, ImageFilter

# ──────────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────────
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "dhruvguptaa/screen-on-off-classifier")
HUB_DATASET_ID = os.environ.get("HUB_DATASET_ID", "dhruvguptaa/screen-on-off-dataset")
N_PER_CLASS = int(os.environ.get("N_PER_CLASS", "2000"))
EPOCHS = int(os.environ.get("EPOCHS", "80"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))
LR = float(os.environ.get("LR", "1e-3"))
PATIENCE = int(os.environ.get("PATIENCE", "10"))
SEED = int(os.environ.get("SEED", "42"))
# Working directories; overridable through the environment like the
# settings above, with the original hard-coded paths as defaults.
DATA_DIR = os.environ.get("DATA_DIR", "/tmp/screen_data")
SAVE_DIR = os.environ.get("SAVE_DIR", "/tmp/model_output")


# ──────────────────────────────────────────────────────────────────
# SYNTHETIC DATA GENERATION
# ──────────────────────────────────────────────────────────────────
def gaussian_blob(h, w, cx, cy, sigma, intensity):
    """Return an (h, w) float32 map of a 2-D Gaussian centred at (cx, cy).

    The peak value is ``intensity`` and the spread is ``sigma`` pixels.
    """
    y, x = np.ogrid[:h, :w]
    g = np.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma ** 2))
    return (g * intensity).astype(np.float32)


def random_color(min_v=0, max_v=255):
    """Return a random RGB tuple with each channel in [min_v, max_v]."""
    return tuple(random.randint(min_v, max_v) for _ in range(3))


def random_bright_color():
    """Return a random RGB tuple from a fixed palette of UI colours."""
    palettes = [
        (66, 133, 244), (234, 67, 53), (251, 188, 4), (52, 168, 83),
        (255, 255, 255), (245, 245, 245), (33, 150, 243), (76, 175, 80),
        (255, 152, 0), (156, 39, 176), (0, 188, 212), (255, 87, 34),
        (63, 81, 181), (0, 150, 136), (121, 85, 72),
    ]
    return random.choice(palettes)


def add_noise(img_array, strength=5):
    """Add zero-mean Gaussian noise (std ``strength``) and return uint8."""
    noise = np.random.normal(0, strength, img_array.shape).astype(np.float32)
    return np.clip(img_array.astype(np.float32) + noise, 0, 255).astype(np.uint8)


def _blend_colors(c1, c2, t):
    """Return ``c1 * (1 - t) + c2 * t`` as a float32 (h, w, 3) array.

    ``c1``/``c2`` are float32 RGB triples and ``t`` is an (h, w) weight map.
    The weights are computed in float64 and rounded to float32 before the
    multiply, which is the same arithmetic a per-pixel Python loop over
    float32 colour arrays performs, just done in one vectorised step.
    """
    t = np.asarray(t, dtype=np.float64)
    weight_c2 = t.astype(np.float32)[:, :, np.newaxis]
    weight_c1 = (1 - t).astype(np.float32)[:, :, np.newaxis]
    return c1 * weight_c1 + c2 * weight_c2


def _row_weights(h, w):
    """Return an (h, w) map whose value on row ``r`` is ``r / h``."""
    return np.repeat((np.arange(h) / h)[:, np.newaxis], w, axis=1)


def _new_canvas(w, h):
    """Create a blank RGB image and a drawing context for it."""
    img = Image.new("RGB", (w, h))
    return img, ImageDraw.Draw(img)


def _paint_light_app(w, h):
    """Light-themed app: status bar, coloured title bar, text lines, button."""
    img, draw = _new_canvas(w, h)
    bg = random.choice([(255, 255, 255), (245, 245, 245), (250, 250, 250)])
    draw.rectangle([0, 0, w, h], fill=bg)
    bar_h = random.randint(8, 14)
    draw.rectangle(
        [0, 0, w, bar_h],
        fill=random.choice([(50, 50, 50), (33, 33, 33), (66, 133, 244)]),
    )
    for x_pos in [5, 12, 19]:
        draw.rectangle([x_pos, 2, x_pos + 4, bar_h - 2], fill=(200, 200, 200))
    for x_pos in [w - 25, w - 18, w - 10]:
        draw.rectangle([x_pos, 2, x_pos + 5, bar_h - 2], fill=(200, 200, 200))
    title_y = bar_h + 2
    draw.rectangle([0, title_y, w, title_y + 12], fill=random_bright_color())
    y = title_y + 18
    while y < h - 20:
        lw = random.randint(w // 3, w - 10)
        draw.rectangle(
            [8, y, 8 + lw, y + random.randint(2, 4)],
            fill=(random.randint(60, 160),) * 3,
        )
        y += random.randint(6, 14)
    if random.random() > 0.4:
        btn_y = random.randint(h // 2, h - 20)
        draw.rounded_rectangle(
            [w // 4, btn_y, 3 * w // 4, btn_y + 12],
            radius=4, fill=random_bright_color(),
        )
    return img


def _paint_dark_app(w, h):
    """Dark-themed app: status bar, bottom tab bar, stacked content cards."""
    img, draw = _new_canvas(w, h)
    bg = random.choice([(18, 18, 18), (30, 30, 30), (20, 20, 30), (25, 25, 35)])
    draw.rectangle([0, 0, w, h], fill=bg)
    bar_h = random.randint(8, 14)
    draw.rectangle([0, 0, w, bar_h], fill=(35, 35, 35))
    accent = random_bright_color()
    nav_h = random.randint(10, 16)
    draw.rectangle([0, h - nav_h, w, h], fill=(30, 30, 30))
    n_tabs = random.randint(3, 5)
    tab_w = w // n_tabs
    active = random.randint(0, n_tabs - 1)
    for i in range(n_tabs):
        x = i * tab_w + tab_w // 2 - 3
        draw.rectangle(
            [x, h - nav_h + 3, x + 6, h - 3],
            fill=accent if i == active else (100, 100, 100),
        )
    y = bar_h + 8
    while y < h - nav_h - 10:
        card_h = random.randint(15, 30)
        draw.rounded_rectangle([6, y, w - 6, y + card_h], radius=3, fill=(45, 45, 50))
        for ly in range(y + 4, min(y + card_h - 4, h), 5):
            draw.rectangle(
                [10, ly, 10 + random.randint(20, w - 20), ly + 2],
                fill=(160, 160, 160),
            )
        y += card_h + 4
    return img


def _paint_media(w, h):
    """Media player: vertical colour gradient, optional play icon, progress bar."""
    c1 = np.array(random_bright_color(), dtype=np.float32)
    c2 = np.array(random_bright_color(), dtype=np.float32)
    arr = _blend_colors(c1, c2, _row_weights(h, w)).astype(np.uint8)
    img = Image.fromarray(arr)
    draw = ImageDraw.Draw(img)
    if random.random() > 0.3:
        cx, cy = w // 2, h // 2
        s = random.randint(10, 20)
        draw.polygon(
            [(cx - s // 2, cy - s), (cx - s // 2, cy + s), (cx + s, cy)],
            fill=(255, 255, 255),
        )
    bar_y = h - random.randint(15, 25)
    p = random.uniform(0.1, 0.9)
    draw.rectangle([5, bar_y, w - 5, bar_y + 3], fill=(100, 100, 100))
    draw.rectangle([5, bar_y, 5 + int((w - 10) * p), bar_y + 3], fill=(255, 0, 0))
    return img


def _paint_gradient_bg(w, h):
    """Home screen: diagonal colour gradient with a grid of app icons."""
    c1 = np.array(random_bright_color(), dtype=np.float32)
    c2 = np.array(random_bright_color(), dtype=np.float32)
    # Diagonal weight: average of the normalised row and column position.
    t = ((np.arange(h) / h)[:, np.newaxis] + (np.arange(w) / w)[np.newaxis, :]) / 2
    arr = _blend_colors(c1, c2, t)
    img = Image.fromarray(arr.astype(np.uint8))
    draw = ImageDraw.Draw(img)
    icon_size = random.randint(10, 16)
    cols_n = random.randint(3, 5)
    rows_n = random.randint(3, 5)
    x_start = (w - cols_n * (icon_size + 6)) // 2
    y_start = random.randint(20, 35)
    for r in range(rows_n):
        for c in range(cols_n):
            x = x_start + c * (icon_size + 6)
            y = y_start + r * (icon_size + 10)
            if y + icon_size < h - 15:
                draw.rounded_rectangle(
                    [x, y, x + icon_size, y + icon_size],
                    radius=3, fill=random_bright_color(),
                )
    return img


def _paint_notification(w, h):
    """Notification shade: dark background, header bar, notification cards."""
    img, draw = _new_canvas(w, h)
    draw.rectangle(
        [0, 0, w, h],
        fill=random.choice([(30, 30, 40), (20, 20, 25), (40, 40, 50)]),
    )
    draw.rectangle([w // 4, 10, 3 * w // 4, 28], fill=(220, 220, 220))
    y = 35
    for _ in range(random.randint(2, 5)):
        if y > h - 20:
            break
        ch = random.randint(18, 30)
        draw.rounded_rectangle([8, y, w - 8, y + ch], radius=4, fill=(50, 50, 60))
        draw.ellipse([12, y + 4, 22, y + 14], fill=random_bright_color())
        draw.rectangle([26, y + 5, w - 15, y + 8], fill=(200, 200, 200))
        draw.rectangle(
            [26, y + 11, random.randint(w // 2, w - 15), y + 14],
            fill=(150, 150, 150),
        )
        y += ch + 5
    return img


def _paint_chat(w, h):
    """Chat app: header, alternating sent/received bubbles, input bar."""
    img, draw = _new_canvas(w, h)
    is_dark = random.random() > 0.5
    bg = (18, 18, 22) if is_dark else random.choice([(230, 230, 235), (255, 255, 255)])
    draw.rectangle([0, 0, w, h], fill=bg)
    draw.rectangle([0, 0, w, 16], fill=random_bright_color())
    y = 22
    while y < h - 25:
        is_sent = random.random() > 0.5
        bw = random.randint(w // 3, 2 * w // 3)
        bh = random.randint(10, 22)
        x1 = w - bw - 8 if is_sent else 8
        if is_sent:
            color = (0, 132, 255)
        else:
            color = (50, 50, 55) if is_dark else (229, 229, 234)
        draw.rounded_rectangle([x1, y, x1 + bw, y + bh], radius=5, fill=color)
        text_c = (255, 255, 255) if is_sent or is_dark else (30, 30, 30)
        for ly in range(y + 3, min(y + bh - 3, h), 5):
            draw.rectangle(
                [x1 + 5, ly, x1 + 5 + random.randint(bw // 3, bw - 8), ly + 2],
                fill=text_c,
            )
        y += bh + random.randint(4, 10)
    draw.rectangle(
        [0, h - 14, w, h],
        fill=(35, 35, 40) if is_dark else (240, 240, 240),
    )
    return img


def _paint_settings(w, h):
    """Settings list: title bar and rows with a coloured icon and a label."""
    img, draw = _new_canvas(w, h)
    is_dark = random.random() > 0.5
    bg = (0, 0, 0) if is_dark else (242, 242, 247)
    draw.rectangle([0, 0, w, h], fill=bg)
    bar_h = 16
    draw.rectangle(
        [w // 4, 4, 3 * w // 4, 12],
        fill=(255, 255, 255) if is_dark else (0, 0, 0),
    )
    y = bar_h + 4
    row_bg = (28, 28, 30) if is_dark else (255, 255, 255)
    label_c = (220, 220, 220) if is_dark else (30, 30, 30)
    while y < h - 8:
        rh = random.randint(12, 18)
        draw.rectangle([0, y, w, y + rh], fill=row_bg)
        draw.rounded_rectangle(
            [8, y + 3, 18, y + rh - 3], radius=2, fill=random_bright_color(),
        )
        draw.rectangle(
            [24, y + rh // 2 - 1, 24 + random.randint(30, w - 40), y + rh // 2 + 1],
            fill=label_c,
        )
        y += rh
    return img


def _paint_browser(w, h):
    """Web browser: address bar plus a random mix of text, images, headings."""
    img, draw = _new_canvas(w, h)
    draw.rectangle([0, 0, w, h], fill=(255, 255, 255))
    bar_h = 14
    draw.rectangle([0, 0, w, bar_h], fill=(245, 245, 245))
    draw.rounded_rectangle(
        [8, 2, w - 8, bar_h - 2],
        radius=4, fill=(255, 255, 255), outline=(200, 200, 200),
    )
    draw.rectangle([14, 5, w // 2, 9], fill=(100, 100, 100))
    y = bar_h + 6
    while y < h - 10:
        elem = random.choice(["text", "image", "heading"])
        if elem == "text":
            for _ in range(random.randint(2, 5)):
                if y > h - 8:
                    break
                draw.rectangle(
                    [8, y, 8 + random.randint(w // 2, w - 12), y + 2],
                    fill=(50, 50, 50),
                )
                y += 5
        elif elem == "image":
            ih = random.randint(20, 40)
            draw.rectangle([8, y, w - 8, y + ih], fill=random_color(120, 240))
            y += ih + 4
        else:
            draw.rectangle([8, y, w // 2 + 20, y + 5], fill=(20, 20, 20))
            y += 10
        y += random.randint(4, 10)
    return img


def _paint_colorful_app(w, h):
    """Colourful app: heavily blurred random colours under a black status bar."""
    arr = np.random.randint(60, 255, (h, w, 3), dtype=np.uint8)
    img = Image.fromarray(arr).filter(ImageFilter.GaussianBlur(radius=8))
    draw = ImageDraw.Draw(img)
    draw.rectangle([0, 0, w, 12], fill=(0, 0, 0))
    for x in [6, 14, 22]:
        draw.rectangle([x, 3, x + 5, 9], fill=(255, 255, 255))
    return img


def _paint_lock_screen_on(w, h):
    """Lit lock screen: blurred vertical gradient with clock/date blocks."""
    c1 = np.array(random_bright_color(), dtype=np.float32)
    c2 = np.array(random_bright_color(), dtype=np.float32) * 0.4
    arr = _blend_colors(c1, c2, _row_weights(h, w))
    img = Image.fromarray(arr.astype(np.uint8)).filter(
        ImageFilter.GaussianBlur(radius=3))
    draw = ImageDraw.Draw(img)
    draw.rectangle([w // 4, h // 4, 3 * w // 4, h // 4 + 20], fill=(255, 255, 255))
    draw.rectangle([w // 3, h // 4 + 25, 2 * w // 3, h // 4 + 30], fill=(220, 220, 220))
    return img


# Order matters: random.choice() indexes into this sequence, so keeping the
# original style order keeps a seeded run choosing the same styles.
_ON_STYLE_PAINTERS = (
    _paint_light_app,
    _paint_dark_app,
    _paint_media,
    _paint_gradient_bg,
    _paint_notification,
    _paint_chat,
    _paint_settings,
    _paint_browser,
    _paint_colorful_app,
    _paint_lock_screen_on,
)


def generate_on_screen(size: Tuple[int, int] = (128, 128)):
    """Generate a realistic 'screen ON' image with varied UI layouts.

    Args:
        size: ``(width, height)`` of the image in pixels. The layouts use
            fixed pixel offsets tuned for the default 128x128 size.

    Returns:
        A PIL RGB image: one randomly chosen UI style, brightened if too
        dark, with optional vignette and glass reflection, plus sensor noise.
    """
    w, h = size
    painter = random.choice(_ON_STYLE_PAINTERS)
    img = painter(w, h)

    # Post-processing: a lit screen should never be very dark overall.
    arr = np.array(img, dtype=np.float32)
    if arr.mean() < 80:
        arr = arr * (100 / max(arr.mean(), 1))
        arr = np.clip(arr, 0, 255)

    # Vignette
    if random.random() > 0.3:
        vs = random.uniform(0.02, 0.08)
        y_v, x_v = np.ogrid[:h, :w]
        r = np.sqrt((x_v - w / 2) ** 2 + (y_v - h / 2) ** 2)
        r_max = np.sqrt((w / 2) ** 2 + (h / 2) ** 2)
        arr = arr * (1 - vs * (r / r_max) ** 2)[:, :, np.newaxis]

    # Glass reflection
    if random.random() > 0.5:
        rs = random.uniform(0.01, 0.04)
        blob = gaussian_blob(
            h, w, random.randint(0, w), random.randint(0, h),
            random.uniform(30, 60), 255 * rs,
        )
        arr += blob[:, :, np.newaxis]

    arr = np.clip(arr, 0, 255).astype(np.uint8)
    arr = add_noise(arr, random.uniform(2, 6))
    return Image.fromarray(arr)


def generate_off_screen(size: Tuple[int, int] = (128, 128)):
    """Generate a realistic 'screen OFF' image with dark glass effects.

    Args:
        size: ``(width, height)`` of the image in pixels.

    Returns:
        A PIL RGB image: a dark base in one of nine styles with randomly
        layered glare, scene reflections, fingerprints and ambient gradient,
        capped in mean brightness and finished with sensor noise.
    """
    w, h = size
    style = random.choice([
        "dark_clean", "dark_glare", "reflection_heavy", "fingerprints",
        "ambient_bright", "sunset_reflection", "indoor_reflection",
        "near_black", "blue_ambient",
    ])

    # Base layer
    if style == "near_black":
        arr = np.random.uniform(0, 8, (h, w, 3)).astype(np.float32)
    elif style == "blue_ambient":
        arr = np.full(
            (h, w, 3),
            [random.uniform(5, 20), random.uniform(8, 25), random.uniform(15, 40)],
            dtype=np.float32,
        )
        arr += np.random.uniform(-3, 3, (h, w, 3))
    elif style == "ambient_bright":
        base = random.uniform(15, 45)
        arr = np.full((h, w, 3), base, dtype=np.float32)
        if random.random() > 0.5:
            # Warm tint
            arr[:, :, 0] += random.uniform(5, 15)
            arr[:, :, 2] -= random.uniform(0, 5)
        else:
            # Cool tint
            arr[:, :, 2] += random.uniform(5, 15)
            arr[:, :, 0] -= random.uniform(0, 5)
    elif style == "sunset_reflection":
        arr = np.zeros((h, w, 3), dtype=np.float32)
        # Fresh random draws per row give the fade a streaky look; the loop
        # is kept so the random stream is consumed exactly as before.
        for row in range(h):
            t = row / h
            arr[row, :, 0] = random.uniform(20, 60) * (1 - t)
            arr[row, :, 1] = random.uniform(10, 30) * (1 - t)
            arr[row, :, 2] = random.uniform(5, 20) * (1 - t)
    else:
        base = random.uniform(3, 25)
        arr = np.full((h, w, 3), base, dtype=np.float32)
        arr[:, :, 0] += random.uniform(-3, 10)
        arr[:, :, 1] += random.uniform(-3, 10)
        arr[:, :, 2] += random.uniform(-3, 15)

    # Glare
    if style in ["dark_glare", "ambient_bright", "indoor_reflection"] or random.random() > 0.4:
        for _ in range(random.randint(1, 3)):
            blob = gaussian_blob(
                h, w, random.randint(0, w), random.randint(0, h),
                random.uniform(10, 50), random.uniform(30, 180),
            )
            glare_c = np.array([
                random.uniform(0.8, 1), random.uniform(0.8, 1), random.uniform(0.8, 1),
            ])
            arr += blob[:, :, np.newaxis] * glare_c

    # Scene reflection
    if style in ["reflection_heavy", "indoor_reflection", "sunset_reflection"] or random.random() > 0.5:
        refl = np.zeros((h, w, 3), dtype=np.float32)
        for _ in range(random.randint(2, 6)):
            rx1, ry1 = random.randint(0, w - 10), random.randint(0, h - 10)
            rx2 = min(w, rx1 + random.randint(10, 50))
            ry2 = min(h, ry1 + random.randint(10, 50))
            refl[ry1:ry2, rx1:rx2] = np.array(random_color(20, 120), dtype=np.float32)
        refl_img = Image.fromarray(refl.astype(np.uint8)).filter(
            ImageFilter.GaussianBlur(radius=random.uniform(8, 20)))
        arr += np.array(refl_img, dtype=np.float32) * random.uniform(0.04, 0.15)

    # Fingerprints
    if style == "fingerprints" or random.random() > 0.5:
        y_fp, x_fp = np.ogrid[:h, :w]
        for _ in range(random.randint(1, 4)):
            cx, cy = random.randint(10, w - 10), random.randint(10, h - 10)
            angle = random.uniform(0, math.pi)
            sx, sy = random.uniform(8, 25), random.uniform(5, 15)
            # Rotate pixel coordinates into the smudge's own axes.
            dx = (x_fp - cx) * math.cos(angle) + (y_fp - cy) * math.sin(angle)
            dy = -(x_fp - cx) * math.sin(angle) + (y_fp - cy) * math.cos(angle)
            smudge = np.exp(-((dx / sx) ** 2 + (dy / sy) ** 2) * 2) * random.uniform(3, 12)
            arr += smudge[:, :, np.newaxis]

    # Ambient gradient
    if random.random() > 0.3:
        direction = random.choice(["top", "bottom", "left", "right", "corner"])
        if direction == "top":
            grad = np.linspace(1, 0, h)[:, np.newaxis] * np.ones(w)
        elif direction == "bottom":
            grad = np.linspace(0, 1, h)[:, np.newaxis] * np.ones(w)
        elif direction == "left":
            grad = np.ones((h, 1)) * np.linspace(1, 0, w)
        elif direction == "right":
            grad = np.ones((h, 1)) * np.linspace(0, 1, w)
        else:
            y_g, x_g = np.ogrid[:h, :w]
            grad = np.sqrt((x_g / w) ** 2 + (y_g / h) ** 2) / np.sqrt(2)
        arr += grad[:, :, np.newaxis] * random.uniform(3, 20)

    # Cap brightness so an OFF screen never looks lit.
    arr = np.clip(arr, 0, 255)
    if arr.mean() > 70:
        arr = arr * (60 / arr.mean())
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    arr = add_noise(arr, random.uniform(2, 8))
    return Image.fromarray(arr)


def generate_dataset(output_dir, n_per_class=2000, size=(128, 128), seed=42):
    """Write ``n_per_class`` ON and OFF PNG images under ``output_dir``.

    Images go to ``<output_dir>/on`` and ``<output_dir>/off``. Both the
    ``random`` and NumPy generators are seeded with ``seed`` first.

    Returns:
        ``output_dir`` unchanged, for convenient chaining.
    """
    random.seed(seed)
    np.random.seed(seed)
    on_dir = Path(output_dir) / "on"
    off_dir = Path(output_dir) / "off"
    on_dir.mkdir(parents=True, exist_ok=True)
    off_dir.mkdir(parents=True, exist_ok=True)

    print(f"Generating {n_per_class} ON images...")
    for i in range(n_per_class):
        generate_on_screen(size).save(on_dir / f"on_{i:05d}.png")
        if (i + 1) % 500 == 0:
            print(f" ON: {i+1}/{n_per_class}")

    print(f"Generating {n_per_class} OFF images...")
    for i in range(n_per_class):
        generate_off_screen(size).save(off_dir / f"off_{i:05d}.png")
        if (i + 1) % 500 == 0:
            print(f" OFF: {i+1}/{n_per_class}")

    print(f"Dataset: {n_per_class*2} images in {output_dir}")
    return output_dir


# ──────────────────────────────────────────────────────────────────
# MODEL
# ──────────────────────────────────────────────────────────────────
class ScreenClassifier(nn.Module):
    """
    Tiny 3-layer CNN for binary screen ON/OFF classification.
    Input: [B, 1, 64, 64] grayscale | Output: [B, 1] logit
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(nn.Flatten(), nn.Dropout(0.3), nn.Linear(64, 1))
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise conv/linear weights; reset BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return the raw ON-class logit for each image in the batch."""
        return self.classifier(self.features(x))


# ──────────────────────────────────────────────────────────────────
# EARLY STOPPING
# ──────────────────────────────────────────────────────────────────
class EarlyStopping:
    """Track the best validation loss and signal when it stops improving.

    A snapshot of the model's state dict is kept for the best loss seen, so
    the best weights can be restored after training.
    """

    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience      # epochs without improvement tolerated
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.best_state = None
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss, model):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(model.state_dict())
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop

    def restore_best(self, model):
        """Load the best recorded weights into ``model`` (no-op if none recorded)."""
        if self.best_state is not None:
            model.load_state_dict(self.best_state)


# ──────────────────────────────────────────────────────────────────
# DATA LOADING
# ──────────────────────────────────────────────────────────────────
class _TransformSubset(torch.utils.data.Dataset):
    """View of selected ``ImageFolder`` samples with its own transform.

    Lets the train and validation splits share one file listing while
    applying different transforms (augmented vs. plain).
    """

    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        img_path, label = self.dataset.samples[self.indices[idx]]
        # Close the file handle promptly; convert() returns a loaded copy.
        with Image.open(img_path) as fh:
            img = fh.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label


def get_transforms(augment=True):
    """Return the preprocessing pipeline (64x64 grayscale, normalised to [-1, 1]).

    With ``augment=True`` random crop/flip/rotation/affine/colour jitter are
    applied as well; otherwise the image is only resized and normalised.
    """
    if augment:
        return transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((72, 72)),
            transforms.RandomCrop((64, 64)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomRotation(10),
            transforms.RandomAffine(0, translate=(0.08, 0.08), scale=(0.9, 1.1)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
    return transforms.Compose([
        transforms.Grayscale(1),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


def build_dataloaders(data_dir, batch_size=32, val_split=0.2, seed=42):
    """Split the image folder into train/val loaders.

    Returns:
        ``(train_loader, val_loader, class_to_idx)``. The split is seeded, the
        training split is augmented and shuffled, the validation split is not.
    """
    # Only the file listing is used here; each split applies its own transform.
    full_ds = datasets.ImageFolder(root=data_dir)
    print(f"[DATA] Classes: {full_ds.class_to_idx}, Total: {len(full_ds)}")

    n_total = len(full_ds)
    n_val = max(1, int(n_total * val_split))
    n_train = n_total - n_val
    gen = torch.Generator().manual_seed(seed)
    train_idx, val_idx = random_split(range(n_total), [n_train, n_val], generator=gen)

    train_ds = _TransformSubset(full_ds, train_idx.indices, get_transforms(True))
    val_ds = _TransformSubset(full_ds, val_idx.indices, get_transforms(False))
    print(f"[DATA] Train: {len(train_ds)}, Val: {len(val_ds)}")

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=False)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=False)
    return train_loader, val_loader, full_ds.class_to_idx


# ──────────────────────────────────────────────────────────────────
# TRAINING
# ──────────────────────────────────────────────────────────────────
def _evaluate(model, loader, criterion, device):
    """Run one eval-mode pass; return ``(mean_loss, tp, tn, fp, fn)``.

    Predictions use a 0.5 threshold on the sigmoid of the logit, with
    label 1 treated as the positive class.
    """
    model.eval()
    loss_sum, n_seen = 0.0, 0
    tp = tn = fp = fn = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device).float()
            logits = model(imgs).squeeze(1)
            loss_sum += criterion(logits, labels).item() * imgs.size(0)
            preds = (torch.sigmoid(logits) >= 0.5).float()
            tp += ((preds == 1) & (labels == 1)).sum().item()
            tn += ((preds == 0) & (labels == 0)).sum().item()
            fp += ((preds == 1) & (labels == 0)).sum().item()
            fn += ((preds == 0) & (labels == 1)).sum().item()
            n_seen += imgs.size(0)
    return loss_sum / max(n_seen, 1), tp, tn, fp, fn


def train_model(data_dir, epochs, batch_size, lr, patience, seed, save_dir):
    """Train the classifier, evaluate it, and save weights and metrics.

    Args:
        data_dir: Image folder with one sub-directory per class.
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size for both loaders.
        lr: Initial Adam learning rate (cosine-annealed over ``epochs``).
        patience: Early-stopping patience in epochs.
        seed: Seed for torch, NumPy, ``random`` and the train/val split.
        save_dir: Output directory for ``.pth``, ``.pt`` and ``metrics.json``.

    Returns:
        ``(model, metrics)`` where ``model`` holds the best-validation-loss
        weights and ``metrics`` is the dict written to ``metrics.json``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cpu")
    print(f"[TRAIN] Device: {device}")

    train_loader, val_loader, class_to_idx = build_dataloaders(data_dir, batch_size, seed=seed)

    model = ScreenClassifier().to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[MODEL] Parameters: {n_params:,}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # max(..., 1) keeps the scheduler well-defined if epochs is 0.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(epochs, 1))
    early_stop = EarlyStopping(patience=patience)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_acc = 0.0
    epochs_trained = 0  # stays 0 if the loop never runs

    for epoch in range(1, epochs + 1):
        epochs_trained = epoch

        # Train
        model.train()
        t_loss, t_correct, t_total = 0.0, 0, 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device).float()
            logits = model(imgs).squeeze(1)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t_loss += loss.item() * imgs.size(0)
            t_correct += ((torch.sigmoid(logits) >= 0.5).float() == labels).sum().item()
            t_total += imgs.size(0)
        t_loss /= t_total
        t_acc = t_correct / t_total

        # Validate
        v_loss, tp, tn, fp, fn = _evaluate(model, val_loader, criterion, device)
        v_acc = (tp + tn) / max(tp + tn + fp + fn, 1)

        scheduler.step()
        lr_now = optimizer.param_groups[0]["lr"]

        history["train_loss"].append(t_loss)
        history["val_loss"].append(v_loss)
        history["train_acc"].append(t_acc)
        history["val_acc"].append(v_acc)
        if v_acc > best_val_acc:
            best_val_acc = v_acc

        # Update early stopping before logging so the printed counter
        # reflects this epoch rather than the previous one.
        stop_now = early_stop.step(v_loss, model)
        print(
            f"Epoch {epoch:3d}/{epochs} | "
            f"Train Loss: {t_loss:.4f} Acc: {t_acc:.4f} | "
            f"Val Loss: {v_loss:.4f} Acc: {v_acc:.4f} | "
            f"LR: {lr_now:.6f} | ES: {early_stop.counter}/{early_stop.patience}"
        )
        if stop_now:
            print(f"\n[EARLY STOP] Epoch {epoch}. Restoring best weights.")
            break

    # Always evaluate and save the best-validation-loss weights, also when
    # training ran all epochs without triggering early stopping; otherwise
    # the "*_best" files would contain last-epoch weights.
    early_stop.restore_best(model)

    # ── Final evaluation ──
    _, tp, tn, fp, fn = _evaluate(model, val_loader, criterion, device)
    final_acc = (tp + tn) / max(tp + tn + fp + fn, 1)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    rule = "=" * 60
    print(f"\n{rule}")
    print("FINAL EVALUATION")
    print(rule)
    print(f"Accuracy: {final_acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"TP={tp}, TN={tn}, FP={fp}, FN={fn}")
    print(rule)

    # ── Save ──
    pth_path = os.path.join(save_dir, "screen_classifier_best.pth")
    torch.save(model.state_dict(), pth_path)
    print(f"[SAVE] State dict: {pth_path}")

    pt_path = os.path.join(save_dir, "screen_classifier_best.pt")
    try:
        scripted = torch.jit.script(model)
        scripted.save(pt_path)
        print(f"[SAVE] TorchScript: {pt_path}")
    except Exception as e:
        # TorchScript export is optional; the state dict is already saved.
        print(f"[WARN] TorchScript failed: {e}")

    # Save metrics
    metrics = {
        "accuracy": final_acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "best_val_acc": best_val_acc,
        "n_params": n_params,
        "epochs_trained": epochs_trained,
        "class_to_idx": class_to_idx,
    }
    with open(os.path.join(save_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    # Benchmark single-image CPU inference latency.
    dummy = torch.randn(1, 1, 64, 64)
    times = []
    for _ in range(200):
        t0 = time.perf_counter()
        with torch.no_grad():
            model(dummy)
        times.append(time.perf_counter() - t0)
    avg_ms = np.mean(times) * 1000
    print(f"[BENCH] Inference: {avg_ms:.2f} ms avg (200 runs)")

    return model, metrics


# ──────────────────────────────────────────────────────────────────
# HUB UPLOAD
# ──────────────────────────────────────────────────────────────────
def push_to_hub(save_dir, hub_model_id, hub_dataset_id, data_dir):
    """Push model + dataset to Hugging Face Hub.

    Writes a model card into ``save_dir`` and a dataset card into
    ``data_dir`` (both as ``README.md``), then uploads each folder. The
    access token is read from the ``HF_TOKEN`` environment variable.
    Upload errors propagate to the caller; repo-creation errors are only
    reported, since the repo may already exist.
    """
    # Imported lazily so the rest of the pipeline works without the package.
    from huggingface_hub import HfApi, upload_folder

    api = HfApi()
    token = os.environ.get("HF_TOKEN")

    # ── Upload model ──
    print(f"\n[HUB] Pushing model to {hub_model_id}...")
    try:
        api.create_repo(hub_model_id, exist_ok=True, token=token)
    except Exception as e:
        # Best effort: report and let upload_folder surface a real failure.
        print(f"[WARN] Could not create model repo {hub_model_id}: {e}")

    # Create model card
    with open(os.path.join(save_dir, "metrics.json"), encoding="utf-8") as f:
        metrics = json.load(f)

    model_card = f"""---
tags:
- pytorch
- image-classification
- binary-classification
- screen-detection
- lightweight
- cpu-optimized
license: mit
metrics:
- accuracy
- f1
---

# Screen ON/OFF Classifier

Ultra-lightweight CNN (~23K params) that classifies phone screen images as **ON** or **OFF**.
Designed for real-time CPU inference (<1ms per frame).

## Performance

| Metric | Value |
|--------|-------|
| Accuracy | {metrics['accuracy']:.4f} |
| Precision | {metrics['precision']:.4f} |
| Recall | {metrics['recall']:.4f} |
| F1 Score | {metrics['f1']:.4f} |
| Parameters | {metrics['n_params']:,} |
| Inference | <1ms (CPU) |

## Usage

```python
import numpy as np
import cv2
import torch
import torch.nn as nn

class ScreenClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(nn.Flatten(), nn.Dropout(0.3), nn.Linear(64, 1))

    def forward(self, x):
        return self.classifier(self.features(x))

# Load
model = ScreenClassifier()
model.load_state_dict(torch.load("screen_classifier_best.pth", map_location="cpu", weights_only=True))
model.eval()

# Predict from OpenCV frame
frame = cv2.imread("phone_screen.jpg")
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
resized = cv2.resize(gray, (64, 64), interpolation=cv2.INTER_AREA)
tensor = torch.from_numpy(resized.astype(np.float32)).div(255.0)
tensor = (tensor - 0.5) / 0.5
tensor = tensor.unsqueeze(0).unsqueeze(0)

with torch.no_grad():
    prob = torch.sigmoid(model(tensor).squeeze()).item()

label = "ON" if prob >= 0.5 else "OFF"
confidence = prob if label == "ON" else 1.0 - prob
print(f"{{label}} (confidence: {{confidence:.1%}})")
```

## Training

Trained on synthetic data ({metrics.get('n_params', 'N/A')} params) with domain randomization.
- ON screens: 10 UI layout styles (light/dark apps, chat, media, settings, browser, etc.)
- OFF screens: 9 dark glass variations (glare, reflections, fingerprints, ambient lighting)

## Files
- `screen_classifier_best.pth` — PyTorch state dict
- `screen_classifier_best.pt` — TorchScript (deploy without class definition)
- `metrics.json` — Training metrics
- `screen_classifier.py` — Full training + inference code
"""

    # The card contains non-ASCII characters, so do not rely on the
    # platform's default encoding.
    with open(os.path.join(save_dir, "README.md"), "w", encoding="utf-8") as f:
        f.write(model_card)

    upload_folder(
        folder_path=save_dir,
        repo_id=hub_model_id,
        token=token,
        commit_message="Upload trained screen ON/OFF classifier",
    )
    print(f"[HUB] Model pushed: https://huggingface.co/{hub_model_id}")

    # ── Upload dataset ──
    print(f"[HUB] Pushing dataset to {hub_dataset_id}...")
    try:
        api.create_repo(hub_dataset_id, repo_type="dataset", exist_ok=True, token=token)
    except Exception as e:
        print(f"[WARN] Could not create dataset repo {hub_dataset_id}: {e}")

    dataset_card = f"""---
tags:
- image-classification
- binary-classification
- screen-detection
- synthetic
license: mit
task_categories:
- image-classification
---

# Screen ON/OFF Dataset (Synthetic)

Synthetic dataset for training phone screen ON/OFF binary classifiers.

## Structure
- `on/` — {N_PER_CLASS} images of screens in ON state (various UI layouts)
- `off/` — {N_PER_CLASS} images of screens in OFF state (dark glass with reflections, glare, fingerprints)

## Generation

Generated using domain randomization (arxiv:1703.06907 principles):
- **ON**: 10 UI styles (light/dark apps, chat, media, browser, settings, notifications, etc.)
- **OFF**: 9 glass surface styles (clean dark, glare, reflections, fingerprints, ambient lighting)

Images are 128x128 RGB PNG. Training resizes to 64x64 grayscale.
"""

    dataset_readme = os.path.join(data_dir, "README.md")
    with open(dataset_readme, "w", encoding="utf-8") as f:
        f.write(dataset_card)

    upload_folder(
        folder_path=data_dir,
        repo_id=hub_dataset_id,
        repo_type="dataset",
        token=token,
        commit_message="Upload synthetic screen ON/OFF dataset",
    )
    print(f"[HUB] Dataset pushed: https://huggingface.co/datasets/{hub_dataset_id}")


# ──────────────────────────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────────────────────────
def main():
    """Run the full pipeline: generate data, train, save, push to the Hub."""
    print("=" * 60)
    print("SCREEN ON/OFF CLASSIFIER — FULL PIPELINE")
    print("=" * 60)

    # 1. Generate dataset
    print("\n[STEP 1] Generating synthetic dataset...")
    t0 = time.time()
    generate_dataset(DATA_DIR, N_PER_CLASS, (128, 128), SEED)
    print(f"Dataset generation: {time.time()-t0:.1f}s")

    # 2. Train model
    print(f"\n[STEP 2] Training model (epochs={EPOCHS}, bs={BATCH_SIZE}, lr={LR})...")
    t0 = time.time()
    model, metrics = train_model(DATA_DIR, EPOCHS, BATCH_SIZE, LR, PATIENCE, SEED, SAVE_DIR)
    print(f"Training: {time.time()-t0:.1f}s")

    # 3. Copy training script to save dir so it is uploaded with the model
    script_src = os.path.abspath(__file__)
    shutil.copy2(script_src, os.path.join(SAVE_DIR, "screen_classifier.py"))

    # 4. Push to Hub (best effort: local artefacts remain on failure)
    print("\n[STEP 3] Pushing to Hugging Face Hub...")
    pushed = True
    try:
        push_to_hub(SAVE_DIR, HUB_MODEL_ID, HUB_DATASET_ID, DATA_DIR)
    except Exception as e:
        pushed = False
        print(f"[WARN] Hub push failed: {e}")
        print("Model files are still saved locally.")

    print("\n" + "=" * 60)
    print("DONE!")
    if pushed:
        print(f"Model: https://huggingface.co/{HUB_MODEL_ID}")
        print(f"Dataset: https://huggingface.co/datasets/{HUB_DATASET_ID}")
    else:
        # Do not advertise Hub URLs that were never updated.
        print(f"Model (local): {SAVE_DIR}")
        print(f"Dataset (local): {DATA_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()