{ "cells": [ { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: transformers in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (5.5.0)\n", "Requirement already satisfied: torch in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (2.11.0)\n", "Requirement already satisfied: pillow in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (12.1.1)\n", "Requirement already satisfied: openai in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (2.30.0)\n", "Requirement already satisfied: huggingface-hub<2.0,>=1.5.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (1.6.0)\n", "Requirement already satisfied: numpy>=1.17 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (2.4.2)\n", "Requirement already satisfied: packaging>=20.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from transformers) (26.0)\n", "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (6.0.3)\n", "Requirement already satisfied: regex>=2025.10.22 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (2026.4.4)\n", "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (0.22.2)\n", "Requirement already satisfied: typer in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (0.24.1)\n", "Requirement already satisfied: safetensors>=0.4.3 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (0.7.0)\n", "Requirement already satisfied: tqdm>=4.27 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (4.67.3)\n", "Requirement already satisfied: filelock>=3.10.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (3.25.0)\n", "Requirement already satisfied: fsspec>=2023.5.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (2026.2.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.3.2 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (1.3.2)\n", "Requirement already satisfied: httpx<1,>=0.23.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (0.28.1)\n", "Requirement already satisfied: typing-extensions>=4.1.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (4.15.0)\n", "Requirement already satisfied: anyio in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (4.12.1)\n", "Requirement already satisfied: certifi in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (2026.2.25)\n", "Requirement already satisfied: httpcore==1.* in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (1.0.9)\n", "Requirement already satisfied: idna in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (3.11)\n", "Requirement already satisfied: h11>=0.16 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (0.16.0)\n", "Requirement already satisfied: setuptools<82 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from torch) (81.0.0)\n", "Requirement already satisfied: sympy>=1.13.3 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from torch) (1.14.0)\n", "Requirement already satisfied: networkx>=2.5.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from torch) (3.6.1)\n", "Requirement already satisfied: jinja2 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from torch) (3.1.6)\n", "Requirement already satisfied: distro<2,>=1.7.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from openai) (1.9.0)\n", "Requirement already satisfied: jiter<1,>=0.10.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from openai) (0.13.0)\n", "Requirement already satisfied: pydantic<3,>=1.9.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from openai) (2.12.5)\n", "Requirement already satisfied: sniffio in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from openai) (1.3.1)\n", "Requirement already satisfied: annotated-types>=0.6.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from pydantic<3,>=1.9.0->openai) (0.7.0)\n", "Requirement already satisfied: pydantic-core==2.41.5 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from pydantic<3,>=1.9.0->openai) (2.41.5)\n", "Requirement already satisfied: typing-inspection>=0.4.2 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from pydantic<3,>=1.9.0->openai) (0.4.2)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from sympy>=1.13.3->torch) (1.3.0)\n", "Requirement already satisfied: colorama in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from tqdm>=4.27->transformers) (0.4.6)\n", "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from jinja2->torch) (3.0.3)\n", "Requirement already satisfied: click>=8.2.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from typer->transformers) (8.3.1)\n", "Requirement already satisfied: shellingham>=1.3.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from typer->transformers) (1.5.4)\n", "Requirement already satisfied: rich>=12.3.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from typer->transformers) (14.3.3)\n", "Requirement already satisfied: annotated-doc>=0.0.2 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from typer->transformers) (0.0.4)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from rich>=12.3.0->typer->transformers) (4.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from rich>=12.3.0->typer->transformers) (2.19.2)\n", "Requirement already satisfied: mdurl~=0.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from markdown-it-py>=2.2.0->rich>=12.3.0->typer->transformers) (0.1.2)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "[notice] A new release of pip is available: 25.3 -> 26.0.1\n", "[notice] To update, run: python.exe -m pip install --upgrade pip\n" ] } ], "source": [ "%pip install transformers torch pillow openai\n", "from transformers import pipeline\n", "from PIL import Image\n", "import os\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model path exists: True\n", "Image folder exists: True\n", "Images: ['Cheetah_032.jpg', 'Leopard_001.jpg', 'Lion_003.jpg', 'Puma_001.jpg', 'Tiger_001.jpg']\n" ] } ], "source": [ "MODEL_PATH = \"./cat-vit\"\n", "IMAGE_FOLDER = \"./Cats-classification-app/example_images\"\n", "\n", "labels = [\"cheetah\", \"leopard\", \"lion\", \"puma\", \"tiger\"]\n", "clip_labels = [f\"a photo of a {label}\" for label in labels]\n", "\n", "print(\"Model path exists:\", os.path.exists(MODEL_PATH))\n", "print(\"Image folder exists:\", os.path.exists(IMAGE_FOLDER))\n", "print(\"Images:\", [f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith((\".jpg\", \".jpeg\", \".png\"))])" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1bf87a05bbc346c9b3f30eb950c1f3a5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading weights: 0%| | 0/200 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fad1554b05bf40d7b31480f8daa8ad35", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading weights: 0%| | 0/398 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[1mCLIPModel LOAD REPORT\u001b[0m from: openai/clip-vit-base-patch32\n", "Key | Status | | \n", "-------------------------------------+------------+--+-\n", "text_model.embeddings.position_ids | UNEXPECTED | | \n", "vision_model.embeddings.position_ids | UNEXPECTED | | \n", "\n", "Notes:\n", "- UNEXPECTED:\tcan be ignored when loading from different task/architecture; not ok if you expect identical arch.\n" ] } ], "source": [ "custom_model = pipeline(\"image-classification\", model=MODEL_PATH)\n", "\n", "clip_model = pipeline(\n", " \"zero-shot-image-classification\",\n", " model=\"openai/clip-vit-base-patch32\"\n", ")" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bc19664791384aceb2502dfe76b5dd1d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading weights: 0%| | 0/398 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[1mCLIPModel LOAD REPORT\u001b[0m from: openai/clip-vit-base-patch32\n", "Key | Status | | \n", "-------------------------------------+------------+--+-\n", "text_model.embeddings.position_ids | UNEXPECTED | | \n", "vision_model.embeddings.position_ids | UNEXPECTED | | \n", "\n", "Notes:\n", "- UNEXPECTED:\tcan be ignored when loading from different task/architecture; not ok if you expect identical arch.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CLIP model loaded!\n" ] } ], "source": [ "clip_model = pipeline(\n", " \"zero-shot-image-classification\",\n", " model=\"openai/clip-vit-base-patch32\"\n", ")\n", "\n", "print(\"CLIP model loaded!\")" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "def get_true_label(filename):\n", " name = filename.lower()\n", " \n", " if name.startswith(\"cheetah\"):\n", " return \"cheetah\"\n", " elif name.startswith(\"leopard\"):\n", " return \"leopard\"\n", " elif name.startswith(\"lion\"):\n", " return \"lion\"\n", " elif name.startswith(\"puma\"):\n", " return \"puma\"\n", " elif name.startswith(\"tiger\"):\n", " return \"tiger\"\n", " else:\n", " return \"unknown\"" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found images: ['Cheetah_032.jpg', 'Leopard_001.jpg', 'Lion_003.jpg', 'Puma_001.jpg', 'Tiger_001.jpg']\n", "results length: 5\n", " image true_label custom_pred custom_score clip_pred clip_score \\\n", "0 Cheetah_032.jpg cheetah cheetah 0.5264 cheetah 0.8319 \n", "1 Leopard_001.jpg leopard leopard 0.5127 leopard 0.9232 \n", "2 Lion_003.jpg lion lion 0.5408 lion 0.9949 \n", "3 Puma_001.jpg puma puma 0.6112 puma 0.9986 \n", "4 Tiger_001.jpg tiger tiger 0.6976 tiger 0.9892 \n", "\n", " custom_correct clip_correct \n", "0 True True \n", "1 True True \n", "2 True True \n", "3 True True \n", "4 True True \n", "columns: ['image', 'true_label', 'custom_pred', 'custom_score', 'clip_pred', 'clip_score', 'custom_correct', 'clip_correct']\n" ] } ], "source": [ "results = []\n", "\n", "id2label = {\n", " 0: \"cheetah\",\n", " 1: \"leopard\",\n", " 2: \"lion\",\n", " 3: \"puma\",\n", " 4: \"tiger\"\n", "}\n", "\n", "image_files = sorted([\n", " f for f in os.listdir(IMAGE_FOLDER)\n", " if f.lower().endswith((\".jpg\", \".jpeg\", \".png\"))\n", "])\n", "\n", "print(\"Found images:\", image_files)\n", "\n", "for img_file in image_files:\n", " image_path = os.path.join(IMAGE_FOLDER, img_file)\n", " image = Image.open(image_path).convert(\"RGB\")\n", " true_label = get_true_label(img_file)\n", "\n", " custom_result = custom_model(image)[0]\n", " raw_custom_label = custom_result[\"label\"]\n", " custom_score = float(custom_result[\"score\"])\n", "\n", " if raw_custom_label.startswith(\"LABEL_\"):\n", " label_id = int(raw_custom_label.split(\"_\")[1])\n", " custom_pred = id2label[label_id]\n", " else:\n", " custom_pred = raw_custom_label.lower()\n", "\n", " clip_result = clip_model(image, candidate_labels=clip_labels)[0]\n", " clip_pred = clip_result[\"label\"].replace(\"a photo of a \", \"\").lower()\n", " clip_score = float(clip_result[\"score\"])\n", "\n", " results.append({\n", " \"image\": img_file,\n", " \"true_label\": true_label,\n", " \"custom_pred\": custom_pred,\n", " \"custom_score\": round(custom_score, 4),\n", " \"clip_pred\": clip_pred,\n", " \"clip_score\": round(clip_score, 4),\n", " \"custom_correct\": custom_pred == true_label,\n", " \"clip_correct\": clip_pred == true_label,\n", " })\n", "\n", "print(\"results length:\", len(results))\n", "\n", "df = pd.DataFrame(results)\n", "print(df)\n", "print(\"columns:\", df.columns.tolist())" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Custom accuracy: 1.0\n", "CLIP accuracy: 1.0\n" ] } ], "source": [ "custom_accuracy = df[\"custom_correct\"].mean()\n", "clip_accuracy = df[\"clip_correct\"].mean()\n", "\n", "print(\"Custom accuracy:\", round(custom_accuracy, 4))\n", "print(\"CLIP accuracy:\", round(clip_accuracy, 4))" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved to comparison_results.csv\n" ] } ], "source": [ "df.to_csv(\"comparison_results.csv\", index=False)\n", "print(\"Saved to comparison_results.csv\")" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "MODEL_PATH = \"./cat-vit\"\n", "IMAGE_FOLDER = \"./Cats-classification-app/example_images\"" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "import os\n", "from openai import OpenAI\n", "\n", "os.environ[\"OPENAI_API_KEY\"] = \"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\"\n", "\n", "client = OpenAI()" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "def predict_openai_label(image_path):\n", " with open(image_path, \"rb\") as image_file:\n", " image_base64 = base64.b64encode(image_file.read()).decode(\"utf-8\")\n", "\n", " response = client.responses.create(\n", " model=\"gpt-4.1-mini\",\n", " input=[\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\n", " \"type\": \"input_text\",\n", " \"text\": \"Classify this image as exactly one of these labels: cheetah, leopard, lion, puma, tiger. Return only one label in lowercase.\"\n", " },\n", " {\n", " \"type\": \"input_image\",\n", " \"image_url\": f\"data:image/jpeg;base64,{image_base64}\"\n", " }\n", " ]\n", " }\n", " ]\n", " )\n", "\n", " return response.output_text.strip().lower()" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "def predict_openai_label(image_path):\n", " with open(image_path, \"rb\") as image_file:\n", " image_base64 = base64.b64encode(image_file.read()).decode(\"utf-8\")\n", "\n", " response = client.responses.create(\n", " model=\"gpt-4.1-mini\",\n", " input=[\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\n", " \"type\": \"input_text\",\n", " \"text\": \"Classify this image as exactly one of these labels: cheetah, leopard, lion, puma, tiger. Return only one label in lowercase.\"\n", " },\n", " {\n", " \"type\": \"input_image\",\n", " \"image_url\": f\"data:image/jpeg;base64,{image_base64}\"\n", " }\n", " ]\n", " }\n", " ]\n", " )\n", "\n", " return response.output_text.strip().lower()" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | image | \n", "true_label | \n", "custom_pred | \n", "custom_score | \n", "clip_pred | \n", "clip_score | \n", "openai_pred | \n", "custom_correct | \n", "clip_correct | \n", "openai_correct | \n", "
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "Cheetah_032.jpg | \n", "cheetah | \n", "cheetah | \n", "0.5264 | \n", "cheetah | \n", "0.8319 | \n", "ERROR: name 'base64' is not defined | \n", "True | \n", "True | \n", "False | \n", "
| 1 | \n", "Leopard_001.jpg | \n", "leopard | \n", "leopard | \n", "0.5127 | \n", "leopard | \n", "0.9232 | \n", "ERROR: name 'base64' is not defined | \n", "True | \n", "True | \n", "False | \n", "
| 2 | \n", "Lion_003.jpg | \n", "lion | \n", "lion | \n", "0.5408 | \n", "lion | \n", "0.9949 | \n", "ERROR: name 'base64' is not defined | \n", "True | \n", "True | \n", "False | \n", "
| 3 | \n", "Puma_001.jpg | \n", "puma | \n", "puma | \n", "0.6112 | \n", "puma | \n", "0.9986 | \n", "ERROR: name 'base64' is not defined | \n", "True | \n", "True | \n", "False | \n", "
| 4 | \n", "Tiger_001.jpg | \n", "tiger | \n", "tiger | \n", "0.6976 | \n", "tiger | \n", "0.9892 | \n", "ERROR: name 'base64' is not defined | \n", "True | \n", "True | \n", "False | \n", "