| |
| """ |
| Complete Screen ON/OFF Training Pipeline |
| ========================================= |
| 1. Generates synthetic dataset (2000 images/class) |
| 2. Trains the lightweight CNN with early stopping |
| 3. Evaluates with detailed metrics |
| 4. Saves model (.pth + TorchScript .pt) |
| 5. Pushes model to Hugging Face Hub |
| 6. Uploads dataset to HF Hub |
| |
| Can be run standalone or via hf_jobs. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import copy |
| import random |
| import math |
| import json |
| import argparse |
| from pathlib import Path |
| from typing import Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader, random_split |
| from torchvision import datasets, transforms |
| from PIL import Image, ImageDraw, ImageFilter |
|
|
| |
| |
| |
|
|
# Pipeline configuration. Every setting can be overridden with an environment
# variable of the same name, so the script can be driven (e.g. by hf_jobs)
# without editing the source.
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "dhruvguptaa/screen-on-off-classifier")
HUB_DATASET_ID = os.environ.get("HUB_DATASET_ID", "dhruvguptaa/screen-on-off-dataset")
N_PER_CLASS = int(os.environ.get("N_PER_CLASS", "2000"))  # synthetic images per class
EPOCHS = int(os.environ.get("EPOCHS", "80"))               # upper bound; early stopping may end sooner
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))
LR = float(os.environ.get("LR", "1e-3"))                   # initial Adam learning rate
PATIENCE = int(os.environ.get("PATIENCE", "10"))           # epochs without val-loss improvement before stopping
SEED = int(os.environ.get("SEED", "42"))
# Working directories: previously hard-coded, now overridable like the rest
# (defaults unchanged).
DATA_DIR = os.environ.get("DATA_DIR", "/tmp/screen_data")
SAVE_DIR = os.environ.get("SAVE_DIR", "/tmp/model_output")
|
|
| |
| |
| |
|
|
def gaussian_blob(h, w, cx, cy, sigma, intensity):
    """Return an (h, w) float32 map of an isotropic 2-D Gaussian bump.

    The peak is at column ``cx`` / row ``cy`` with height ``intensity``;
    ``sigma`` is the standard deviation in pixels.
    """
    rows, cols = np.ogrid[:h, :w]
    sq_dist = (cols - cx) ** 2 + (rows - cy) ** 2
    falloff = np.exp(-sq_dist / (2 * sigma ** 2))
    return (falloff * intensity).astype(np.float32)
|
|
def random_color(min_v=0, max_v=255):
    """Return an (R, G, B) tuple, each channel drawn uniformly from [min_v, max_v]."""
    channels = []
    for _channel in range(3):
        channels.append(random.randint(min_v, max_v))
    return tuple(channels)
|
|
def random_bright_color():
    """Pick one RGB tuple from a fixed palette of saturated UI accent colours and near-whites."""
    accent_palette = (
        (66, 133, 244), (234, 67, 53), (251, 188, 4), (52, 168, 83),
        (255, 255, 255), (245, 245, 245), (33, 150, 243), (76, 175, 80),
        (255, 152, 0), (156, 39, 176), (0, 188, 212), (255, 87, 34),
        (63, 81, 181), (0, 150, 136), (121, 85, 72),
    )
    picked = random.choice(accent_palette)
    return picked
|
|
def add_noise(img_array, strength=5):
    """Add zero-mean Gaussian noise (std ``strength``) and return a clipped uint8 array."""
    jitter = np.random.normal(0, strength, img_array.shape).astype(np.float32)
    as_float = img_array.astype(np.float32)
    noisy = np.clip(as_float + jitter, 0, 255)
    return noisy.astype(np.uint8)
|
|
|
|
def generate_on_screen(size=(128, 128)):
    """Generate a realistic 'screen ON' image with varied UI layouts.

    One of ten layout styles is chosen at random and drawn with PIL. A shared
    post-processing pass then:
    - brightens overly dark results,
    - optionally adds a vignette and a faint light blob,
    - adds Gaussian noise.

    Args:
        size: ``(width, height)`` of the output image in pixels.

    Returns:
        An RGB ``PIL.Image.Image``.

    Note:
        Randomness comes from both the ``random`` and ``numpy.random`` global
        generators, in a fixed call order. Seeding both makes the output
        reproducible, which is why every branch keeps its draws in the same
        order.
    """
    w, h = size
    img = Image.new("RGB", (w, h))
    draw = ImageDraw.Draw(img)

    style = random.choice([
        "light_app", "dark_app", "media", "gradient_bg", "notification",
        "chat", "settings", "browser", "colorful_app", "lock_screen_on",
    ])

    if style == "light_app":
        # Light app: status bar with icon stubs, coloured title bar,
        # lines of "text", and an optional call-to-action button.
        bg = random.choice([(255, 255, 255), (245, 245, 245), (250, 250, 250)])
        draw.rectangle([0, 0, w, h], fill=bg)
        bar_h = random.randint(8, 14)
        draw.rectangle([0, 0, w, bar_h], fill=random.choice([(50, 50, 50), (33, 33, 33), (66, 133, 244)]))
        for x_pos in [5, 12, 19]:
            draw.rectangle([x_pos, 2, x_pos + 4, bar_h - 2], fill=(200, 200, 200))
        for x_pos in [w - 25, w - 18, w - 10]:
            draw.rectangle([x_pos, 2, x_pos + 5, bar_h - 2], fill=(200, 200, 200))
        title_y = bar_h + 2
        draw.rectangle([0, title_y, w, title_y + 12], fill=random_bright_color())
        y = title_y + 18
        while y < h - 20:
            lw = random.randint(w // 3, w - 10)
            # Positional box (with its randint) is evaluated before the fill randint.
            draw.rectangle([8, y, 8 + lw, y + random.randint(2, 4)], fill=(random.randint(60, 160),) * 3)
            y += random.randint(6, 14)
        if random.random() > 0.4:
            btn_y = random.randint(h // 2, h - 20)
            draw.rounded_rectangle([w // 4, btn_y, 3 * w // 4, btn_y + 12], radius=4, fill=random_bright_color())

    elif style == "dark_app":
        # Dark app: status bar, bottom tab bar with one highlighted tab, stacked cards.
        bg = random.choice([(18, 18, 18), (30, 30, 30), (20, 20, 30), (25, 25, 35)])
        draw.rectangle([0, 0, w, h], fill=bg)
        bar_h = random.randint(8, 14)
        draw.rectangle([0, 0, w, bar_h], fill=(35, 35, 35))
        accent = random_bright_color()
        nav_h = random.randint(10, 16)
        draw.rectangle([0, h - nav_h, w, h], fill=(30, 30, 30))
        n_tabs = random.randint(3, 5)
        tab_w = w // n_tabs
        active = random.randint(0, n_tabs - 1)
        for i in range(n_tabs):
            x = i * tab_w + tab_w // 2 - 3
            draw.rectangle([x, h - nav_h + 3, x + 6, h - 3], fill=accent if i == active else (100, 100, 100))
        y = bar_h + 8
        while y < h - nav_h - 10:
            card_h = random.randint(15, 30)
            draw.rounded_rectangle([6, y, w - 6, y + card_h], radius=3, fill=(45, 45, 50))
            for ly in range(y + 4, min(y + card_h - 4, h), 5):
                draw.rectangle([10, ly, 10 + random.randint(20, w - 20), ly + 2], fill=(160, 160, 160))
            y += card_h + 4

    elif style == "media":
        # Media player: vertical colour gradient, optional play triangle, progress bar.
        arr = np.zeros((h, w, 3), dtype=np.uint8)
        c1 = np.array(random_bright_color(), dtype=np.float32)
        c2 = np.array(random_bright_color(), dtype=np.float32)
        for row in range(h):
            arr[row] = (c1 * (1 - row / h) + c2 * (row / h)).astype(np.uint8)
        img = Image.fromarray(arr)
        draw = ImageDraw.Draw(img)
        if random.random() > 0.3:
            cx, cy = w // 2, h // 2
            s = random.randint(10, 20)
            draw.polygon([(cx - s // 2, cy - s), (cx - s // 2, cy + s), (cx + s, cy)], fill=(255, 255, 255))
        bar_y = h - random.randint(15, 25)
        p = random.uniform(0.1, 0.9)
        draw.rectangle([5, bar_y, w - 5, bar_y + 3], fill=(100, 100, 100))
        draw.rectangle([5, bar_y, 5 + int((w - 10) * p), bar_y + 3], fill=(255, 0, 0))

    elif style == "gradient_bg":
        # Home screen: diagonal wallpaper gradient plus a grid of app icons.
        c1 = np.array(random_bright_color(), dtype=np.float32)
        c2 = np.array(random_bright_color(), dtype=np.float32)
        # Vectorised form of the former per-pixel Python loop.
        # t = (row/h + col/w) / 2 and 1 - t are computed in float64, then cast
        # to float32, just as the scalar Python floats were when multiplied
        # with the float32 colours. The resulting pixels are therefore
        # bit-identical.
        t = (np.arange(h)[:, None] / h + np.arange(w)[None, :] / w) / 2
        weight_c2 = t.astype(np.float32)[:, :, np.newaxis]
        weight_c1 = (1 - t).astype(np.float32)[:, :, np.newaxis]
        arr = c1 * weight_c1 + c2 * weight_c2
        img = Image.fromarray(arr.astype(np.uint8))
        draw = ImageDraw.Draw(img)
        icon_size = random.randint(10, 16)
        cols_n = random.randint(3, 5)
        rows_n = random.randint(3, 5)
        x_start = (w - cols_n * (icon_size + 6)) // 2
        y_start = random.randint(20, 35)
        for r in range(rows_n):
            for c in range(cols_n):
                x = x_start + c * (icon_size + 6)
                y = y_start + r * (icon_size + 10)
                if y + icon_size < h - 15:
                    draw.rounded_rectangle([x, y, x + icon_size, y + icon_size], radius=3, fill=random_bright_color())

    elif style == "notification":
        # Notification shade: dark background, header bar, 2-5 notification cards.
        draw.rectangle([0, 0, w, h], fill=random.choice([(30, 30, 40), (20, 20, 25), (40, 40, 50)]))
        draw.rectangle([w // 4, 10, 3 * w // 4, 28], fill=(220, 220, 220))
        y = 35
        for _ in range(random.randint(2, 5)):
            if y > h - 20:
                break
            ch = random.randint(18, 30)
            draw.rounded_rectangle([8, y, w - 8, y + ch], radius=4, fill=(50, 50, 60))
            draw.ellipse([12, y + 4, 22, y + 14], fill=random_bright_color())
            draw.rectangle([26, y + 5, w - 15, y + 8], fill=(200, 200, 200))
            draw.rectangle([26, y + 11, random.randint(w // 2, w - 15), y + 14], fill=(150, 150, 150))
            y += ch + 5

    elif style == "chat":
        # Messaging app: sent (right, blue) and received (left, grey) bubbles.
        is_dark = random.random() > 0.5
        bg = (18, 18, 22) if is_dark else random.choice([(230, 230, 235), (255, 255, 255)])
        draw.rectangle([0, 0, w, h], fill=bg)
        draw.rectangle([0, 0, w, 16], fill=random_bright_color())
        y = 22
        while y < h - 25:
            is_sent = random.random() > 0.5
            bw = random.randint(w // 3, 2 * w // 3)
            bh = random.randint(10, 22)
            x1 = w - bw - 8 if is_sent else 8
            color = (0, 132, 255) if is_sent else ((50, 50, 55) if is_dark else (229, 229, 234))
            draw.rounded_rectangle([x1, y, x1 + bw, y + bh], radius=5, fill=color)
            text_c = (255, 255, 255) if is_sent or is_dark else (30, 30, 30)
            for ly in range(y + 3, min(y + bh - 3, h), 5):
                draw.rectangle([x1 + 5, ly, x1 + 5 + random.randint(bw // 3, bw - 8), ly + 2], fill=text_c)
            y += bh + random.randint(4, 10)
        draw.rectangle([0, h - 14, w, h], fill=(35, 35, 40) if is_dark else (240, 240, 240))

    elif style == "settings":
        # Settings list: rows with a coloured icon square and a label bar.
        is_dark = random.random() > 0.5
        bg = (0, 0, 0) if is_dark else (242, 242, 247)
        draw.rectangle([0, 0, w, h], fill=bg)
        bar_h = 16
        draw.rectangle([w // 4, 4, 3 * w // 4, 12], fill=(255, 255, 255) if is_dark else (0, 0, 0))
        y = bar_h + 4
        row_bg = (28, 28, 30) if is_dark else (255, 255, 255)
        while y < h - 8:
            rh = random.randint(12, 18)
            draw.rectangle([0, y, w, y + rh], fill=row_bg)
            draw.rounded_rectangle([8, y + 3, 18, y + rh - 3], radius=2, fill=random_bright_color())
            label_c = (220, 220, 220) if is_dark else (30, 30, 30)
            draw.rectangle([24, y + rh // 2 - 1, 24 + random.randint(30, w - 40), y + rh // 2 + 1], fill=label_c)
            y += rh

    elif style == "browser":
        # Web page: URL bar, then a random mix of text runs, images and headings.
        draw.rectangle([0, 0, w, h], fill=(255, 255, 255))
        bar_h = 14
        draw.rectangle([0, 0, w, bar_h], fill=(245, 245, 245))
        draw.rounded_rectangle([8, 2, w - 8, bar_h - 2], radius=4, fill=(255, 255, 255), outline=(200, 200, 200))
        draw.rectangle([14, 5, w // 2, 9], fill=(100, 100, 100))
        y = bar_h + 6
        while y < h - 10:
            elem = random.choice(["text", "image", "heading"])
            if elem == "text":
                for _ in range(random.randint(2, 5)):
                    if y > h - 8:
                        break
                    draw.rectangle([8, y, 8 + random.randint(w // 2, w - 12), y + 2], fill=(50, 50, 50))
                    y += 5
            elif elem == "image":
                ih = random.randint(20, 40)
                draw.rectangle([8, y, w - 8, y + ih], fill=random_color(120, 240))
                y += ih + 4
            else:
                draw.rectangle([8, y, w // 2 + 20, y + 5], fill=(20, 20, 20))
                y += 10
            y += random.randint(4, 10)

    elif style == "colorful_app":
        # Busy content: heavily blurred random colours under a black status bar.
        arr = np.random.randint(60, 255, (h, w, 3), dtype=np.uint8)
        img = Image.fromarray(arr).filter(ImageFilter.GaussianBlur(radius=8))
        draw = ImageDraw.Draw(img)
        draw.rectangle([0, 0, w, 12], fill=(0, 0, 0))
        for x in [6, 14, 22]:
            draw.rectangle([x, 3, x + 5, 9], fill=(255, 255, 255))

    else:
        # "lock_screen_on": blurred wallpaper gradient fading toward a darker
        # colour, with two bright bars (presumably clock/date placeholders).
        c1 = np.array(random_bright_color(), dtype=np.float32)
        c2 = np.array(random_bright_color(), dtype=np.float32) * 0.4
        arr = np.zeros((h, w, 3), dtype=np.float32)
        for row in range(h):
            arr[row] = c1 * (1 - row / h) + c2 * (row / h)
        img = Image.fromarray(arr.astype(np.uint8)).filter(ImageFilter.GaussianBlur(radius=3))
        draw = ImageDraw.Draw(img)
        draw.rectangle([w // 4, h // 4, 3 * w // 4, h // 4 + 20], fill=(255, 255, 255))
        draw.rectangle([w // 3, h // 4 + 25, 2 * w // 3, h // 4 + 30], fill=(220, 220, 220))

    # A lit screen should not look near-black. Rescale dark layouts so their
    # mean brightness becomes ~100, keeping them separable from OFF screens.
    arr = np.array(img, dtype=np.float32)
    mean_val = arr.mean()
    if mean_val < 80:
        arr = arr * (100 / max(mean_val, 1))
        arr = np.clip(arr, 0, 255)

    # Optional mild radial vignette (2-8% darkening at the corners).
    if random.random() > 0.3:
        vs = random.uniform(0.02, 0.08)
        y_v, x_v = np.ogrid[:h, :w]
        r = np.sqrt((x_v - w / 2) ** 2 + (y_v - h / 2) ** 2)
        r_max = np.sqrt((w / 2) ** 2 + (h / 2) ** 2)
        arr = arr * (1 - vs * (r / r_max) ** 2)[:, :, np.newaxis]

    # Optional faint, wide light blob (e.g. a room reflection on the glass).
    if random.random() > 0.5:
        rs = random.uniform(0.01, 0.04)
        blob = gaussian_blob(h, w, random.randint(0, w), random.randint(0, h), random.uniform(30, 60), 255 * rs)
        arr += blob[:, :, np.newaxis]

    arr = np.clip(arr, 0, 255).astype(np.uint8)
    arr = add_noise(arr, random.uniform(2, 6))
    return Image.fromarray(arr)
|
|
|
|
def generate_off_screen(size=(128, 128)):
    """Generate a realistic 'screen OFF' image with dark glass effects.

    One of nine base styles sets a dark background. Optional effect layers are
    then added: glare blobs, blurred reflections, fingerprint smudges, and a
    directional ambient-light gradient. Finally:
    - the image is clipped,
    - if its mean exceeds 70 it is rescaled to a mean of 60, so OFF screens
      stay dark,
    - Gaussian noise is added.

    Args:
        size: ``(width, height)`` of the output image in pixels.

    Returns:
        An RGB ``PIL.Image.Image``.

    Note:
        Output depends on the exact order of draws from both the ``random``
        and ``numpy.random`` global generators. Some draws happen only when a
        short-circuited ``or`` reaches them, so reordering statements changes
        the dataset produced for a given seed.
    """
    w, h = size

    style = random.choice([
        "dark_clean", "dark_glare", "reflection_heavy", "fingerprints",
        "ambient_bright", "sunset_reflection", "indoor_reflection",
        "near_black", "blue_ambient",
    ])

    # ---- Base layer (float32, roughly 0-60) ----
    if style == "near_black":
        arr = np.random.uniform(0, 8, (h, w, 3)).astype(np.float32)
    elif style == "blue_ambient":
        # Dim bluish tint (blue channel has the highest range) plus per-pixel jitter.
        arr = np.full((h, w, 3), [random.uniform(5,20), random.uniform(8,25), random.uniform(15,40)], dtype=np.float32)
        arr += np.random.uniform(-3, 3, (h, w, 3))  # may go slightly negative; clipped below
    elif style == "ambient_bright":
        # Brighter flat grey with either a warm (red up) or cool (blue up) cast.
        base = random.uniform(15, 45)
        arr = np.full((h, w, 3), base, dtype=np.float32)
        if random.random() > 0.5:
            arr[:,:,0] += random.uniform(5, 15)
            arr[:,:,2] -= random.uniform(0, 5)
        else:
            arr[:,:,2] += random.uniform(5, 15)
            arr[:,:,0] -= random.uniform(0, 5)
    elif style == "sunset_reflection":
        # Warm tones fading to black toward the bottom row.
        # NOTE(review): the colour is re-sampled for every row, which produces
        # horizontal banding rather than one smooth gradient. Confirm this is
        # intended before changing it, since it alters the seeded dataset.
        arr = np.zeros((h, w, 3), dtype=np.float32)
        for row in range(h):
            t = row / h
            arr[row,:,0] = random.uniform(20, 60) * (1-t)
            arr[row,:,1] = random.uniform(10, 30) * (1-t)
            arr[row,:,2] = random.uniform(5, 20) * (1-t)
    else:
        # dark_clean / dark_glare / reflection_heavy / fingerprints /
        # indoor_reflection: flat dark grey with a small per-channel tint.
        base = random.uniform(3, 25)
        arr = np.full((h, w, 3), base, dtype=np.float32)
        arr[:,:,0] += random.uniform(-3, 10)
        arr[:,:,1] += random.uniform(-3, 10)
        arr[:,:,2] += random.uniform(-3, 15)

    # ---- Glare: 1-3 near-white Gaussian blobs ----
    # Always applied for the listed styles, otherwise with probability 0.6.
    if style in ["dark_glare", "ambient_bright", "indoor_reflection"] or random.random() > 0.4:
        for _ in range(random.randint(1, 3)):
            blob = gaussian_blob(h, w, random.randint(0,w), random.randint(0,h),
                                 random.uniform(10,50), random.uniform(30,180))
            glare_c = np.array([random.uniform(0.8,1), random.uniform(0.8,1), random.uniform(0.8,1)])
            arr += blob[:,:,np.newaxis] * glare_c

    # ---- Reflections: blurred coloured rectangles added at 4-15% strength ----
    # Always applied for the listed styles, otherwise with probability 0.5.
    if style in ["reflection_heavy", "indoor_reflection", "sunset_reflection"] or random.random() > 0.5:
        refl = np.zeros((h, w, 3), dtype=np.float32)
        for _ in range(random.randint(2, 6)):
            rx1, ry1 = random.randint(0, w-10), random.randint(0, h-10)
            rx2 = min(w, rx1 + random.randint(10, 50))
            ry2 = min(h, ry1 + random.randint(10, 50))
            refl[ry1:ry2, rx1:rx2] = np.array(random_color(20, 120), dtype=np.float32)
        refl_img = Image.fromarray(refl.astype(np.uint8)).filter(
            ImageFilter.GaussianBlur(radius=random.uniform(8, 20)))
        arr += np.array(refl_img, dtype=np.float32) * random.uniform(0.04, 0.15)

    # ---- Fingerprints: 1-4 rotated elliptical smudges (amplitude 3-12) ----
    # Always applied for "fingerprints", otherwise with probability 0.5.
    if style == "fingerprints" or random.random() > 0.5:
        for _ in range(random.randint(1, 4)):
            cx, cy = random.randint(10, w-10), random.randint(10, h-10)
            angle = random.uniform(0, math.pi)
            sx, sy = random.uniform(8, 25), random.uniform(5, 15)
            y_fp, x_fp = np.ogrid[:h, :w]
            # Rotate pixel offsets into the smudge's own axes.
            dx = (x_fp-cx)*math.cos(angle) + (y_fp-cy)*math.sin(angle)
            dy = -(x_fp-cx)*math.sin(angle) + (y_fp-cy)*math.cos(angle)
            arr += (np.exp(-((dx/sx)**2 + (dy/sy)**2)*2) * random.uniform(3, 12))[:,:,np.newaxis]

    # ---- Ambient light gradient (probability 0.7, amplitude 3-20) ----
    if random.random() > 0.3:
        direction = random.choice(["top", "bottom", "left", "right", "corner"])
        if direction == "top":
            grad = np.linspace(1, 0, h)[:, np.newaxis] * np.ones(w)
        elif direction == "bottom":
            grad = np.linspace(0, 1, h)[:, np.newaxis] * np.ones(w)
        elif direction == "left":
            grad = np.ones((h, 1)) * np.linspace(1, 0, w)
        elif direction == "right":
            grad = np.ones((h, 1)) * np.linspace(0, 1, w)
        else:
            # Grows from the top-left corner (0) to the bottom-right corner (~1).
            y_g, x_g = np.ogrid[:h, :w]
            grad = np.sqrt((x_g/w)**2 + (y_g/h)**2) / np.sqrt(2)
        arr += grad[:,:,np.newaxis] * random.uniform(3, 20)

    # Keep OFF screens dark: rescale overly bright results to a mean of 60.
    arr = np.clip(arr, 0, 255)
    if arr.mean() > 70:
        arr = arr * (60 / arr.mean())

    arr = np.clip(arr, 0, 255).astype(np.uint8)
    arr = add_noise(arr, random.uniform(2, 8))
    return Image.fromarray(arr)
|
|
|
|
def generate_dataset(output_dir, n_per_class=2000, size=(128, 128), seed=42):
    """Write a balanced synthetic ON/OFF image dataset to disk.

    Creates ``<output_dir>/on`` and ``<output_dir>/off`` (an ImageFolder
    layout) and fills each with ``n_per_class`` PNGs named ``on_00000.png`` /
    ``off_00000.png`` and so on. All ON images are generated before any OFF
    image. Both the ``random`` and ``numpy.random`` global generators are
    seeded first, so the same seed reproduces the same files.

    Returns:
        ``output_dir``, unchanged.
    """
    random.seed(seed)
    np.random.seed(seed)

    root = Path(output_dir)
    plan = (
        ("on", "ON", generate_on_screen),
        ("off", "OFF", generate_off_screen),
    )
    class_dirs = {}
    for prefix, _label, _make in plan:
        class_dirs[prefix] = root / prefix
        class_dirs[prefix].mkdir(parents=True, exist_ok=True)

    for prefix, label, make_image in plan:
        print(f"Generating {n_per_class} {label} images...")
        for idx in range(n_per_class):
            make_image(size).save(class_dirs[prefix] / f"{prefix}_{idx:05d}.png")
            done = idx + 1
            if done % 500 == 0:
                print(f" {label}: {done}/{n_per_class}")

    print(f"Dataset: {n_per_class*2} images in {output_dir}")
    return output_dir
|
|
|
|
| |
| |
| |
|
|
class ScreenClassifier(nn.Module):
    """
    Tiny 3-layer CNN for binary screen ON/OFF classification.
    Input: [B, 1, 64, 64] grayscale | Output: [B, 1] logit

    Each stage is Conv3x3 (no bias) -> BatchNorm -> ReLU. The first two stages
    end with 2x2 max-pooling; the last ends with global average pooling, so
    the head sees a 64-d vector. The head is Flatten -> Dropout(0.3) ->
    Linear(64, 1). Layer indices inside ``features``/``classifier`` are fixed
    because saved state_dicts are keyed on them.
    """

    def __init__(self):
        super().__init__()
        widths = (16, 32, 64)
        layers = []
        in_ch = 1
        for stage, out_ch in enumerate(widths):
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
            is_last = stage == len(widths) - 1
            # Earlier stages halve the spatial size; the final one collapses it.
            layers.append(nn.AdaptiveAvgPool2d(1) if is_last else nn.MaxPool2d(2))
            in_ch = out_ch
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Flatten(), nn.Dropout(0.3), nn.Linear(in_ch, 1))
        self._init_weights()

    def _init_weights(self):
        """He-normal conv/linear weights, unit BN scale, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        pooled = self.features(x)
        return self.classifier(pooled)
|
|
|
|
| |
| |
| |
|
|
class EarlyStopping:
    """Signal a stop once validation loss fails to improve for ``patience`` epochs.

    An epoch counts as an improvement only if its loss beats the best so far
    by more than ``min_delta``. On every improvement a deep copy of the
    model's ``state_dict`` is kept, so :meth:`restore_best` can reload it.
    """

    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_state = None    # deep copy of the best state_dict so far
        self.counter = 0          # consecutive epochs without improvement
        self.should_stop = False

    def step(self, val_loss, model):
        """Record one epoch's validation loss; return True once training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
            return self.should_stop
        self.best_loss = val_loss
        # A deep copy is required: state_dict() returns live tensors that the
        # optimiser keeps updating.
        self.best_state = copy.deepcopy(model.state_dict())
        self.counter = 0
        return self.should_stop

    def restore_best(self, model):
        """Load the best snapshot into ``model``; do nothing if none was taken."""
        if not self.best_state:
            return
        model.load_state_dict(self.best_state)
|
|
|
|
| |
| |
| |
|
|
class _TransformSubset(torch.utils.data.Dataset):
    """View of an ``ImageFolder``-style dataset restricted to ``indices``, with its own transform.

    This lets the train and validation splits share one file index while
    applying different transforms. Images are loaded directly from
    ``dataset.samples`` (a list of ``(path, label)`` pairs), so the parent
    dataset's own ``transform`` is never used.
    """

    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        img_path, label = self.dataset.samples[self.indices[idx]]
        # Open inside a context manager so the file handle is released
        # deterministically rather than whenever the image object is
        # garbage-collected. DataLoader workers open thousands of files.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
|
|
|
|
def get_transforms(augment=True):
    """Build the torchvision preprocessing pipeline.

    Both variants convert to 1-channel grayscale, end with a 64x64 tensor, and
    normalise with mean 0.5 / std 0.5. With ``augment=True`` the image is
    resized to 72x72, then randomly cropped, flipped, rotated,
    affine-jittered and brightness/contrast-jittered. Otherwise it is simply
    resized to 64x64.
    """
    head = [transforms.Grayscale(1)]
    tail = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    if not augment:
        return transforms.Compose(head + [transforms.Resize((64, 64))] + tail)
    augmentation = [
        transforms.Resize((72, 72)),
        transforms.RandomCrop((64, 64)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(10),
        transforms.RandomAffine(0, translate=(0.08, 0.08), scale=(0.9, 1.1)),
        transforms.ColorJitter(brightness=0.3, contrast=0.3),
    ]
    return transforms.Compose(head + augmentation + tail)
|
|
|
|
def build_dataloaders(data_dir, batch_size=32, val_split=0.2, seed=42, num_workers=2):
    """Split an ImageFolder directory into seeded train/validation DataLoaders.

    The split is drawn with a ``torch.Generator`` seeded by ``seed``, so it is
    reproducible. The training subset gets the augmenting transform and is
    shuffled; the validation subset gets the deterministic transform.

    Args:
        data_dir: Root with one sub-folder per class (``ImageFolder`` layout).
        batch_size: Batch size for both loaders.
        val_split: Fraction of images used for validation (at least one image).
        seed: Seed for the split generator.
        num_workers: DataLoader worker processes (previously fixed at 2).

    Returns:
        ``(train_loader, val_loader, class_to_idx)``.

    Raises:
        ValueError: If the split would leave the training set empty. Previously
            this surfaced later as a ZeroDivisionError during training.
    """
    # Only (path, label) pairs are needed from the parent dataset; each subset
    # applies its own transform, so none is attached here.
    full_ds = datasets.ImageFolder(root=data_dir)
    print(f"[DATA] Classes: {full_ds.class_to_idx}, Total: {len(full_ds)}")

    n_total = len(full_ds)
    n_val = max(1, int(n_total * val_split))
    n_train = n_total - n_val
    if n_train < 1:
        raise ValueError(
            f"val_split={val_split} leaves no training images "
            f"({n_total} images found in {data_dir})"
        )

    gen = torch.Generator().manual_seed(seed)
    train_idx, val_idx = random_split(range(n_total), [n_train, n_val], generator=gen)

    train_ds = _TransformSubset(full_ds, train_idx.indices, get_transforms(True))
    val_ds = _TransformSubset(full_ds, val_idx.indices, get_transforms(False))

    print(f"[DATA] Train: {len(train_ds)}, Val: {len(val_ds)}")

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=False)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=False)
    return train_loader, val_loader, full_ds.class_to_idx
|
|
|
|
| |
| |
| |
|
|
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device).float()
        logits = model(imgs).squeeze(1)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        correct += ((torch.sigmoid(logits) >= 0.5).float() == labels).sum().item()
        seen += imgs.size(0)
    return total_loss / seen, correct / seen


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without gradients.

    A sample is predicted as label 1 when ``sigmoid(logit) >= 0.5``.

    Returns:
        ``(mean_loss, accuracy, (tp, tn, fp, fn))``, with label 1 as the
        positive class.
    """
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    tp = tn = fp = fn = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device).float()
            logits = model(imgs).squeeze(1)
            total_loss += criterion(logits, labels).item() * imgs.size(0)
            preds = (torch.sigmoid(logits) >= 0.5).float()
            correct += (preds == labels).sum().item()
            seen += imgs.size(0)
            tp += ((preds == 1) & (labels == 1)).sum().item()
            tn += ((preds == 0) & (labels == 0)).sum().item()
            fp += ((preds == 1) & (labels == 0)).sum().item()
            fn += ((preds == 0) & (labels == 1)).sum().item()
    return total_loss / seen, correct / seen, (tp, tn, fp, fn)


def _save_artifacts(model, save_dir):
    """Write the state dict (.pth) and, when scriptable, a TorchScript (.pt) copy."""
    pth_path = os.path.join(save_dir, "screen_classifier_best.pth")
    torch.save(model.state_dict(), pth_path)
    print(f"[SAVE] State dict: {pth_path}")

    pt_path = os.path.join(save_dir, "screen_classifier_best.pt")
    try:
        scripted = torch.jit.script(model)
        scripted.save(pt_path)
        print(f"[SAVE] TorchScript: {pt_path}")
    except Exception as e:  # TorchScript is optional; the .pth is the primary artifact
        print(f"[WARN] TorchScript failed: {e}")


def _benchmark_inference(model, n_runs=200):
    """Return the mean wall-clock latency (ms) of one 1x1x64x64 forward pass."""
    model.eval()
    dummy = torch.randn(1, 1, 64, 64)
    times = []
    with torch.no_grad():
        for _ in range(n_runs):
            t0 = time.perf_counter()
            model(dummy)
            times.append(time.perf_counter() - t0)
    return float(np.mean(times)) * 1000


def train_model(data_dir, epochs, batch_size, lr, patience, seed, save_dir):
    """Train ``ScreenClassifier`` on ``data_dir`` and write artifacts to ``save_dir``.

    Steps:
    1. Seed torch/numpy/random and build the train/val loaders.
    2. Train with Adam (weight decay 1e-4) and cosine LR annealing over
       ``epochs``, stopping early when validation loss stalls for
       ``patience`` epochs.
    3. Reload the lowest-validation-loss weights and evaluate them.
    4. Save ``screen_classifier_best.pth``/``.pt``, benchmark CPU latency, and
       write ``metrics.json``.

    Returns:
        ``(model, metrics)``. ``model`` holds the best (lowest val-loss)
        weights; ``metrics`` is the dict also written to ``metrics.json``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cpu")
    print(f"[TRAIN] Device: {device}")

    train_loader, val_loader, class_to_idx = build_dataloaders(data_dir, batch_size, seed=seed)

    model = ScreenClassifier().to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[MODEL] Parameters: {n_params:,}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # max(1, ...) keeps the scheduler valid if epochs == 0.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, epochs))
    early_stop = EarlyStopping(patience=patience)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_acc = 0.0
    epochs_trained = 0  # defined even when the loop body never runs

    for epoch in range(1, epochs + 1):
        t_loss, t_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        v_loss, v_acc, _ = _evaluate(model, val_loader, criterion, device)

        scheduler.step()
        lr_now = optimizer.param_groups[0]["lr"]
        epochs_trained = epoch

        history["train_loss"].append(t_loss)
        history["val_loss"].append(v_loss)
        history["train_acc"].append(t_acc)
        history["val_acc"].append(v_acc)
        best_val_acc = max(best_val_acc, v_acc)

        print(
            f"Epoch {epoch:3d}/{epochs} | "
            f"Train Loss: {t_loss:.4f} Acc: {t_acc:.4f} | "
            f"Val Loss: {v_loss:.4f} Acc: {v_acc:.4f} | "
            f"LR: {lr_now:.6f} | ES: {early_stop.counter}/{early_stop.patience}"
        )

        if early_stop.step(v_loss, model):
            print(f"\n[EARLY STOP] Epoch {epoch}. Restoring best weights.")
            break

    # Always reload the lowest-val-loss snapshot. Previously this happened
    # only on early stop, so a run that used every epoch saved its last-epoch
    # weights under the "best" filenames.
    early_stop.restore_best(model)

    # NOTE: this is the same validation split that selected the checkpoint,
    # so these numbers are optimistic; a held-out test split would be unbiased.
    _, final_acc, (tp, tn, fp, fn) = _evaluate(model, val_loader, criterion, device)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"\n{'='*60}")
    print("FINAL EVALUATION")
    print(f"{'='*60}")
    print(f"Accuracy: {final_acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"TP={tp}, TN={tn}, FP={fp}, FN={fn}")
    print(f"{'='*60}")

    _save_artifacts(model, save_dir)

    avg_ms = _benchmark_inference(model, n_runs=200)
    print(f"[BENCH] Inference: {avg_ms:.2f} ms avg (200 runs)")

    metrics = {
        "accuracy": final_acc, "precision": precision, "recall": recall, "f1": f1,
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "best_val_acc": best_val_acc, "n_params": n_params,
        "epochs_trained": epochs_trained, "class_to_idx": class_to_idx,
        # Additional keys: previously computed but discarded.
        "inference_ms": avg_ms,
        "history": history,
    }
    with open(os.path.join(save_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    return model, metrics
|
|
|
|
| |
| |
| |
|
|
def _ensure_hub_repo(api, repo_id, token, repo_type=None):
    """Create ``repo_id`` on the Hub if needed; warn instead of raising on failure.

    This is best-effort on purpose. If creation fails, the subsequent
    ``upload_folder`` call surfaces any real problem (e.g. a missing repo or
    bad token), so the error is reported rather than silently swallowed.
    """
    try:
        api.create_repo(repo_id, repo_type=repo_type, exist_ok=True, token=token)
    except Exception as e:
        print(f"[WARN] Could not create repo {repo_id}: {e}")


def _count_pngs(folder):
    """Return how many ``*.png`` files sit directly inside ``folder`` (0 if missing)."""
    return sum(1 for _ in Path(folder).glob("*.png"))


def push_to_hub(save_dir, hub_model_id, hub_dataset_id, data_dir):
    """Push model + dataset to Hugging Face Hub.

    Steps:
    1. Read ``metrics.json`` from ``save_dir``.
    2. Write a model card (``README.md``) into ``save_dir`` and upload that
       folder to ``hub_model_id``.
    3. Write a dataset card into ``data_dir`` and upload it to
       ``hub_dataset_id`` as a dataset repo.

    The token comes from the ``HF_TOKEN`` environment variable. If it is
    unset, ``None`` is passed and authentication is left to
    ``huggingface_hub``. Upload errors propagate to the caller.
    """
    from huggingface_hub import HfApi, upload_folder

    api = HfApi()
    token = os.environ.get("HF_TOKEN")

    # ---- Model repo ----
    print(f"\n[HUB] Pushing model to {hub_model_id}...")
    _ensure_hub_repo(api, hub_model_id, token)

    with open(os.path.join(save_dir, "metrics.json"), encoding="utf-8") as f:
        metrics = json.load(f)

    # Count what was actually generated instead of assuming N_PER_CLASS.
    n_on = _count_pngs(os.path.join(data_dir, "on"))
    n_off = _count_pngs(os.path.join(data_dir, "off"))
    n_images = n_on + n_off
    approx_k = round(metrics["n_params"] / 1000)
    if "inference_ms" in metrics:
        inference = f"{metrics['inference_ms']:.2f} ms (CPU, measured)"
    else:
        inference = "<1ms (CPU)"

    model_card = f"""---
tags:
- pytorch
- image-classification
- binary-classification
- screen-detection
- lightweight
- cpu-optimized
license: mit
metrics:
- accuracy
- f1
---

# Screen ON/OFF Classifier

Ultra-lightweight CNN (~{approx_k}K params) that classifies phone screen images as **ON** or **OFF**.
Designed for real-time CPU inference (<1ms per frame).

## Performance

| Metric | Value |
|--------|-------|
| Accuracy | {metrics['accuracy']:.4f} |
| Precision | {metrics['precision']:.4f} |
| Recall | {metrics['recall']:.4f} |
| F1 Score | {metrics['f1']:.4f} |
| Parameters | {metrics['n_params']:,} |
| Inference | {inference} |

## Usage

```python
import numpy as np
import cv2
import torch
import torch.nn as nn

class ScreenClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1, bias=False), nn.BatchNorm2d(16), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(True), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True), nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(nn.Flatten(), nn.Dropout(0.3), nn.Linear(64, 1))

    def forward(self, x):
        return self.classifier(self.features(x))

# Load
model = ScreenClassifier()
model.load_state_dict(torch.load("screen_classifier_best.pth", map_location="cpu", weights_only=True))
model.eval()

# Predict from OpenCV frame
frame = cv2.imread("phone_screen.jpg")
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
resized = cv2.resize(gray, (64, 64), interpolation=cv2.INTER_AREA)
tensor = torch.from_numpy(resized.astype(np.float32)).div(255.0)
tensor = (tensor - 0.5) / 0.5
tensor = tensor.unsqueeze(0).unsqueeze(0)

with torch.no_grad():
    prob = torch.sigmoid(model(tensor).squeeze()).item()

label = "ON" if prob >= 0.5 else "OFF"
confidence = prob if label == "ON" else 1.0 - prob
print(f"{{label}} (confidence: {{confidence:.1%}})")
```

## Training

Trained on {n_images:,} synthetic images generated with domain randomization.
- ON screens: 10 UI layout styles (light/dark apps, chat, media, settings, browser, etc.)
- OFF screens: 9 dark glass variations (glare, reflections, fingerprints, ambient lighting)

## Files

- `screen_classifier_best.pth` — PyTorch state dict
- `screen_classifier_best.pt` — TorchScript (deploy without class definition)
- `metrics.json` — Training metrics
- `screen_classifier.py` — Full training + inference code
"""

    # Explicit UTF-8: the card contains non-ASCII characters (em dashes), and
    # the platform default encoding is not guaranteed to handle them.
    with open(os.path.join(save_dir, "README.md"), "w", encoding="utf-8") as f:
        f.write(model_card)

    upload_folder(
        folder_path=save_dir,
        repo_id=hub_model_id,
        token=token,
        commit_message="Upload trained screen ON/OFF classifier",
    )
    print(f"[HUB] Model pushed: https://huggingface.co/{hub_model_id}")

    # ---- Dataset repo ----
    print(f"[HUB] Pushing dataset to {hub_dataset_id}...")
    _ensure_hub_repo(api, hub_dataset_id, token, repo_type="dataset")

    dataset_card = f"""---
tags:
- image-classification
- binary-classification
- screen-detection
- synthetic
license: mit
task_categories:
- image-classification
---

# Screen ON/OFF Dataset (Synthetic)

Synthetic dataset for training phone screen ON/OFF binary classifiers.

## Structure
- `on/` — {n_on} images of screens in ON state (various UI layouts)
- `off/` — {n_off} images of screens in OFF state (dark glass with reflections, glare, fingerprints)

## Generation
Generated using domain randomization (arxiv:1703.06907 principles):
- **ON**: 10 UI styles (light/dark apps, chat, media, browser, settings, notifications, etc.)
- **OFF**: 9 glass surface styles (clean dark, glare, reflections, fingerprints, ambient lighting)

Images are 128x128 RGB PNG. Training resizes to 64x64 grayscale.
"""
    dataset_readme = os.path.join(data_dir, "README.md")
    with open(dataset_readme, "w", encoding="utf-8") as f:
        f.write(dataset_card)

    upload_folder(
        folder_path=data_dir,
        repo_id=hub_dataset_id,
        repo_type="dataset",
        token=token,
        commit_message="Upload synthetic screen ON/OFF dataset",
    )
    print(f"[HUB] Dataset pushed: https://huggingface.co/datasets/{hub_dataset_id}")
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run the full pipeline: generate data, train, then push to the Hub.

    Configuration comes from the module-level constants. A failed Hub push
    is reported but not fatal; the trained artifacts remain in ``SAVE_DIR``.
    """
    import shutil

    print("=" * 60)
    print("SCREEN ON/OFF CLASSIFIER — FULL PIPELINE")
    print("=" * 60)

    print("\n[STEP 1] Generating synthetic dataset...")
    t0 = time.time()
    generate_dataset(DATA_DIR, N_PER_CLASS, (128, 128), SEED)
    print(f"Dataset generation: {time.time()-t0:.1f}s")

    print(f"\n[STEP 2] Training model (epochs={EPOCHS}, bs={BATCH_SIZE}, lr={LR})...")
    t0 = time.time()
    train_model(DATA_DIR, EPOCHS, BATCH_SIZE, LR, PATIENCE, SEED, SAVE_DIR)
    print(f"Training: {time.time()-t0:.1f}s")

    # Ship this source next to the weights so the Hub repo is self-contained.
    # __file__ is absent when the code is exec'd rather than run as a file;
    # skip the copy instead of crashing after training has already finished.
    script_src = globals().get("__file__")
    if script_src and os.path.isfile(script_src):
        shutil.copy2(os.path.abspath(script_src), os.path.join(SAVE_DIR, "screen_classifier.py"))
    else:
        print("[WARN] Source file not found; screen_classifier.py not copied.")

    print("\n[STEP 3] Pushing to Hugging Face Hub...")
    pushed = False
    try:
        push_to_hub(SAVE_DIR, HUB_MODEL_ID, HUB_DATASET_ID, DATA_DIR)
        pushed = True
    except Exception as e:  # top-level boundary: keep local artifacts usable
        print(f"[WARN] Hub push failed: {e}")
        print("Model files are still saved locally.")

    print("\n" + "=" * 60)
    print("DONE!")
    if pushed:
        print(f"Model: https://huggingface.co/{HUB_MODEL_ID}")
        print(f"Dataset: https://huggingface.co/datasets/{HUB_DATASET_ID}")
    else:
        # Do not advertise Hub URLs that may not exist after a failed push.
        print(f"Model (local): {SAVE_DIR}")
        print(f"Dataset (local): {DATA_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
|