"""Benchmark latency, throughput and peak RSS of a vision-language checkpoint.

Loads ``MODEL_ID``, runs one warm-up generation followed by a fixed list of
text-only and image prompts through ``model.generate``, then writes the
measurements to ``OUT_JSON`` and prints them.
"""

import io
import json
import os
import threading
import time
from dataclasses import asdict, dataclass

import psutil
import requests
import torch
import transformers
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

MODEL_ID = "Dharunkumar9/SmolVLM-256M-Instruct-Agri"
OUT_JSON = os.path.join(os.path.dirname(__file__), "benchmark_results.json")
SAMPLE_IMAGE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"

# Divisor used to convert byte counts reported by psutil into mebibytes.
_BYTES_PER_MB = 1024 * 1024


@dataclass
class CaseResult:
    """Measurements collected for a single benchmark case."""

    name: str
    input_tokens: int
    generated_tokens: int
    latency_s: float
    tokens_per_s: float
    peak_rss_mb: float
    output_preview: str


class MemoryMonitor:
    """Context manager that tracks the peak RSS of a process while active.

    A daemon thread polls ``process.memory_info().rss`` every ``interval_s``
    seconds and keeps the largest value seen in ``max_rss`` (bytes). Because
    this is sampling, spikes shorter than the polling interval can be missed.
    """

    def __init__(self, process: psutil.Process, interval_s: float = 0.01):
        self.process = process
        self.interval_s = interval_s
        self._running = False
        self._thread = None
        self.max_rss = 0

    def _run(self):
        """Poll RSS until ``_running`` is cleared, recording the maximum."""
        while self._running:
            rss = self.process.memory_info().rss
            if rss > self.max_rss:
                self.max_rss = rss
            time.sleep(self.interval_s)

    def __enter__(self):
        self._running = True
        # Seed with the current RSS so max_rss is meaningful even if the
        # poller never gets a chance to run.
        self.max_rss = self.process.memory_info().rss
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self._running = False
        if self._thread is not None:
            self._thread.join(timeout=1)
        # One last sample so growth since the poller's final tick is not lost.
        self.max_rss = max(self.max_rss, self.process.memory_info().rss)


def pick_device() -> tuple[str, torch.dtype]:
    """Return the ``(device, dtype)`` pair to benchmark on.

    Preference order is MPS (float16), then CUDA (bfloat16), then CPU
    (float32).
    """
    if torch.backends.mps.is_available():
        return "mps", torch.float16
    if torch.cuda.is_available():
        return "cuda", torch.bfloat16
    return "cpu", torch.float32


def make_prompt(processor, text: str, with_image: bool):
    """Render a single-turn user message through the processor's chat template.

    Args:
        processor: Object exposing ``apply_chat_template``.
        text: The user's text.
        with_image: When true, an ``{"type": "image"}`` entry is placed before
            the text entry in the message content.

    Returns:
        Whatever ``processor.apply_chat_template`` returns for the message
        list with ``add_generation_prompt=True``.
    """
    content = [{"type": "text", "text": text}]
    if with_image:
        content.insert(0, {"type": "image"})
    messages = [{"role": "user", "content": content}]
    return processor.apply_chat_template(messages, add_generation_prompt=True)


def prepare_inputs(processor, prompt: str, image: Image.Image | None, device: str) -> dict:
    """Run the processor on ``prompt`` (and ``image``) and move tensors to ``device``.

    Args:
        processor: Callable accepting ``text``, ``return_tensors`` and,
            optionally, ``images`` keyword arguments.
        prompt: Prompt text passed as ``text``.
        image: Optional image; when given it is passed as a one-element
            ``images`` list.
        device: Target device for every tensor value.

    Returns:
        A plain dict of the processor's outputs; non-tensor values are passed
        through unchanged.
    """
    kwargs = {"text": prompt, "return_tensors": "pt"}
    if image is not None:
        kwargs["images"] = [image]
    inputs = processor(**kwargs)
    return {
        key: (value.to(device) if torch.is_tensor(value) else value)
        for key, value in inputs.items()
    }


def run_case(
    model,
    processor,
    device: str,
    case_name: str,
    text: str,
    image: Image.Image | None,
    max_new_tokens: int = 64,
) -> CaseResult:
    """Time one greedy ``model.generate`` call and summarise it.

    Args:
        model: Model exposing ``generate``.
        processor: Processor used for prompt templating, tokenisation and
            decoding.
        device: Device the inputs are moved to.
        case_name: Label stored in the result.
        text: User text for the prompt.
        image: Optional image to include in the prompt.
        max_new_tokens: Generation budget passed to ``generate``.

    Returns:
        A ``CaseResult`` with token counts, wall-clock latency, tokens per
        second, peak RSS sampled during generation, and a preview of the
        decoded output (at most 220 characters).
    """
    prompt = make_prompt(processor, text, with_image=image is not None)
    inputs = prepare_inputs(processor, prompt, image, device)
    input_tokens = int(inputs["input_ids"].shape[1])

    proc = psutil.Process(os.getpid())
    # Only the generate() call is inside the timed / memory-monitored region;
    # prompt construction and decoding are excluded.
    with MemoryMonitor(proc) as mon:
        t0 = time.perf_counter()
        with torch.inference_mode():
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                use_cache=True,
            )
        t1 = time.perf_counter()

    latency = t1 - t0
    # The output's second dimension is treated as prompt + new tokens, so the
    # prompt length is subtracted to count only what was generated.
    generated_tokens = int(out.shape[1] - input_tokens)
    tps = float(generated_tokens / latency) if latency > 0 else 0.0

    decoded = processor.batch_decode(out[:, input_tokens:], skip_special_tokens=True)
    preview = (decoded[0] if decoded else "").strip().replace("\n", " ")[:220]

    return CaseResult(
        name=case_name,
        input_tokens=input_tokens,
        generated_tokens=generated_tokens,
        latency_s=round(latency, 3),
        tokens_per_s=round(tps, 3),
        peak_rss_mb=round(mon.max_rss / _BYTES_PER_MB, 2),
        output_preview=preview,
    )


def main() -> None:
    """Load the model, run every benchmark case and write the JSON report.

    Raises:
        requests.HTTPError: If the sample image download returns an HTTP
            error status.
    """
    process = psutil.Process(os.getpid())
    rss_start_mb = process.memory_info().rss / _BYTES_PER_MB

    device, dtype = pick_device()
    print(f"Device: {device}, dtype: {dtype}")

    # load_time_s covers both the processor and the model load, including the
    # transfer of the model to the target device.
    t0 = time.perf_counter()
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation="eager",
    ).to(device)
    model.eval()
    t1 = time.perf_counter()

    load_time_s = t1 - t0
    rss_after_load_mb = process.memory_info().rss / _BYTES_PER_MB

    print("Downloading sample image...")
    response = requests.get(SAMPLE_IMAGE_URL, timeout=30)
    # Fail loudly on an HTTP error instead of handing an error page to PIL.
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert("RGB")

    # Warm-up: one short generation whose result is discarded and never
    # reported.
    _ = run_case(
        model,
        processor,
        device,
        "warmup",
        "Describe this image briefly.",
        image,
        max_new_tokens=16,
    )

    # Each case is (name, prompt text, image or None, max_new_tokens).
    cases = [
        ("text_only_short", "You are an agri assistant. Give 3 tips for identifying early leaf blight.", None, 64),
        (
            "image_short",
            "What do you see in this image? Mention crop/plant clues if visible.",
            image,
            64,
        ),
        (
            "image_long",
            "Analyze this image for agriculture relevance. Return: 1) likely object/plant, 2) possible health indicators, 3) recommended next observation steps, 4) confidence from 0-1.",
            image,
            96,
        ),
    ]

    results = []
    for name, text, img, max_new_tokens in cases:
        print(f"Running case: {name}")
        results.append(
            asdict(run_case(model, processor, device, name, text, img, max_new_tokens=max_new_tokens))
        )

    payload = {
        "model_id": MODEL_ID,
        "device": device,
        "dtype": str(dtype),
        "load_time_s": round(load_time_s, 3),
        "rss_start_mb": round(rss_start_mb, 2),
        "rss_after_load_mb": round(rss_after_load_mb, 2),
        "model_num_parameters": int(model.num_parameters()),
        "transformers_version": transformers.__version__,
        "torch_version": torch.__version__,
        "cases": results,
    }

    with open(OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)

    print(f"Saved results to {OUT_JSON}")
    print(json.dumps(payload, indent=2))


if __name__ == "__main__":
    main()