| import json |
| import os |
| import time |
| import threading |
| from dataclasses import dataclass, asdict |
|
|
| import psutil |
| import requests |
| import torch |
| from PIL import Image |
| from transformers import AutoModelForImageTextToText, AutoProcessor |
|
|
|
|
# Hugging Face Hub identifier passed to ``from_pretrained`` for both the
# processor and the model.
MODEL_ID = "Dharunkumar9/SmolVLM-256M-Instruct-Agri"
# Path of the JSON results file; it is placed in the same directory as this script.
OUT_JSON = os.path.join(os.path.dirname(__file__), "benchmark_results.json")
# Image downloaded over HTTP and reused for the warm-up run and the image cases.
SAMPLE_IMAGE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"
|
|
|
|
@dataclass
class CaseResult:
    """Measurements recorded for a single benchmark case."""

    name: str  # label of the case
    input_tokens: int  # sequence length of the prompt's ``input_ids``
    generated_tokens: int  # output sequence length minus ``input_tokens``
    latency_s: float  # wall-clock seconds spent in ``generate``, rounded to 3 places
    tokens_per_s: float  # generated_tokens / latency, rounded to 3 places
    peak_rss_mb: float  # highest sampled process RSS, in MiB
    output_preview: str  # decoded continuation, single line, at most 220 characters
|
|
|
|
class MemoryMonitor:
    """Context manager that records the peak resident set size of a process.

    While the ``with`` body runs, a daemon thread polls
    ``process.memory_info().rss`` every ``interval_s`` seconds. The largest
    value observed is available as ``max_rss`` (bytes), both during and after
    the ``with`` block.

    Args:
        process: Object exposing ``memory_info()`` whose result has an ``rss``
            attribute (a ``psutil.Process`` in this script).
        interval_s: Seconds between two polls.
    """

    def __init__(self, process: "psutil.Process", interval_s: float = 0.01):
        self.process = process
        self.interval_s = interval_s
        self._stop_event = threading.Event()
        self._thread = None
        self.max_rss = 0

    def _sample(self):
        """Read the current RSS and raise ``max_rss`` if it has grown."""
        rss = self.process.memory_info().rss
        if rss > self.max_rss:
            self.max_rss = rss

    def _run(self):
        """Poll until a stop is requested."""
        while not self._stop_event.is_set():
            self._sample()
            # The wait doubles as the sleep between polls but returns at once
            # when a stop is requested, so leaving the ``with`` block is not
            # delayed by the remainder of the interval.
            self._stop_event.wait(self.interval_s)

    def __enter__(self):
        # Clearing the event allows the same instance to be entered again.
        self._stop_event.clear()
        # Start from the current reading so ``max_rss`` is meaningful even if
        # the body finishes before the first poll.
        self.max_rss = self.process.memory_info().rss
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=1)
        # One last reading: memory that grew after the final poll would
        # otherwise never be observed.
        self._sample()
|
|
|
|
def pick_device():
    """Choose the compute device and the dtype to load the model with.

    Backends are probed in order: MPS (float16), then CUDA (bfloat16). When
    neither reports itself as available, CPU with float32 is used.

    Returns:
        A ``(device_name, torch_dtype)`` tuple.
    """
    accelerators = (
        ("mps", torch.float16, torch.backends.mps.is_available),
        ("cuda", torch.bfloat16, torch.cuda.is_available),
    )
    for device_name, dtype, is_available in accelerators:
        if is_available():
            return device_name, dtype
    return "cpu", torch.float32
|
|
|
|
def make_prompt(processor, text: str, with_image: bool):
    """Render a single-turn user message through the processor's chat template.

    Args:
        processor: Object providing ``apply_chat_template``.
        text: The user's text.
        with_image: When true, an image placeholder entry is placed before the
            text entry in the message content.

    Returns:
        Whatever ``processor.apply_chat_template`` returns when called with
        the message list and ``add_generation_prompt=True``.
    """
    content = [{"type": "text", "text": text}]
    if with_image:
        # The image placeholder goes ahead of the text.
        content.insert(0, {"type": "image"})
    conversation = [{"role": "user", "content": content}]
    return processor.apply_chat_template(conversation, add_generation_prompt=True)
|
|
|
|
def prepare_inputs(processor, prompt: str, image: Image.Image | None, device: str):
    """Run the processor on the prompt (and optional image) and move tensors to ``device``.

    Args:
        processor: Callable accepting ``text``, ``return_tensors`` and,
            when an image is given, ``images``; its result must offer ``items()``.
        prompt: Prompt string passed as ``text``.
        image: Image passed as a one-element ``images`` list, or ``None`` to
            omit the ``images`` argument entirely.
        device: Target device for every tensor value.

    Returns:
        A plain ``dict`` with the processor's keys; tensor values are moved to
        ``device``, all other values are passed through unchanged.
    """
    processor_args = {"text": prompt, "return_tensors": "pt"}
    if image is not None:
        processor_args["images"] = [image]
    encoded = processor(**processor_args)

    on_device = {}
    for key, value in encoded.items():
        on_device[key] = value.to(device) if torch.is_tensor(value) else value
    return on_device
|
|
|
|
def _sync_device(device: str) -> None:
    """Block until work queued on an asynchronous accelerator has finished."""
    if device.startswith("cuda"):
        torch.cuda.synchronize(torch.device(device))
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()


def run_case(model, processor, device: str, case_name: str, text: str, image: Image.Image | None, max_new_tokens: int = 64):
    """Run one greedy generation and measure latency, throughput and peak RSS.

    Args:
        model: Model exposing ``generate``.
        processor: Processor used to build the prompt, encode the inputs and
            decode the output.
        device: Device the inputs are moved to; also selects which backend is
            synchronised around the timed region.
        case_name: Label stored in the result.
        text: User text for the prompt.
        image: Image to include, or ``None`` for a text-only case.
        max_new_tokens: Generation length limit passed to ``generate``.

    Returns:
        A ``CaseResult`` with the measurements and a short preview of the
        decoded continuation.
    """
    prompt = make_prompt(processor, text, with_image=image is not None)
    inputs = prepare_inputs(processor, prompt, image, device)
    input_tokens = int(inputs["input_ids"].shape[1])

    proc = psutil.Process(os.getpid())
    with MemoryMonitor(proc) as mon:
        # Accelerator backends execute asynchronously. Synchronising before
        # each timestamp keeps pending input transfers out of the measurement
        # and makes sure generation has really finished when the clock stops.
        _sync_device(device)
        t0 = time.perf_counter()
        with torch.inference_mode():
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                use_cache=True,
            )
        _sync_device(device)
        t1 = time.perf_counter()

    latency = t1 - t0
    # The output sequence starts with the prompt; only the tail is new.
    generated_tokens = int(out.shape[1] - input_tokens)
    tps = float(generated_tokens / latency) if latency > 0 else 0.0

    decoded = processor.batch_decode(out[:, input_tokens:], skip_special_tokens=True)
    # Single-line preview, capped at 220 characters.
    preview = (decoded[0] if decoded else "").strip().replace("\n", " ")[:220]

    return CaseResult(
        name=case_name,
        input_tokens=input_tokens,
        generated_tokens=generated_tokens,
        latency_s=round(latency, 3),
        tokens_per_s=round(tps, 3),
        peak_rss_mb=round(mon.max_rss / (1024 * 1024), 2),
        output_preview=preview,
    )
|
|
|
|
def main():
    """Load the model, run the benchmark cases and write the results as JSON.

    Steps: record the starting RSS, load processor and model (timed),
    download the sample image, perform one warm-up generation, run the three
    benchmark cases, then write the collected payload to ``OUT_JSON`` and
    print it.

    Raises:
        requests.HTTPError: If the sample image request returns an error status.
    """
    # Local imports: only this function needs these names.
    import io

    import transformers

    process = psutil.Process(os.getpid())
    rss_start_mb = process.memory_info().rss / (1024 * 1024)

    device, dtype = pick_device()
    print(f"Device: {device}, dtype: {dtype}")

    t0 = time.perf_counter()
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation="eager",
    ).to(device)
    model.eval()
    t1 = time.perf_counter()

    load_time_s = t1 - t0
    rss_after_load_mb = process.memory_info().rss / (1024 * 1024)

    print("Downloading sample image...")
    response = requests.get(SAMPLE_IMAGE_URL, timeout=30)
    # Fail here on an HTTP error instead of handing an error page to PIL.
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert("RGB")

    # Warm-up run; its measurements are discarded.
    run_case(
        model,
        processor,
        device,
        "warmup",
        "Describe this image briefly.",
        image,
        max_new_tokens=16,
    )

    # Each entry: (case name, prompt text, image or None, max_new_tokens).
    cases = [
        ("text_only_short", "You are an agri assistant. Give 3 tips for identifying early leaf blight.", None, 64),
        (
            "image_short",
            "What do you see in this image? Mention crop/plant clues if visible.",
            image,
            64,
        ),
        (
            "image_long",
            "Analyze this image for agriculture relevance. Return: 1) likely object/plant, 2) possible health indicators, 3) recommended next observation steps, 4) confidence from 0-1.",
            image,
            96,
        ),
    ]

    results = []
    for name, text, img, max_new_tokens in cases:
        print(f"Running case: {name}")
        results.append(asdict(run_case(model, processor, device, name, text, img, max_new_tokens=max_new_tokens)))

    payload = {
        "model_id": MODEL_ID,
        "device": device,
        "dtype": str(dtype),
        "load_time_s": round(load_time_s, 3),
        "rss_start_mb": round(rss_start_mb, 2),
        "rss_after_load_mb": round(rss_after_load_mb, 2),
        "model_num_parameters": int(model.num_parameters()),
        "transformers_version": transformers.__version__,
        "torch_version": torch.__version__,
        "cases": results,
    }

    with open(OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)

    print(f"Saved results to {OUT_JSON}")
    print(json.dumps(payload, indent=2))
|
|
|
|
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|