adisaljusi committed on
Commit
3e6a4eb
·
1 Parent(s): 21b4eed

Refactor code structure for improved readability

Browse files
.gitattributes CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ requirements.md
2
+ __pycache__/
3
+ *.pyc
4
+ .env
5
+ *.pt
6
+ *.pth
7
+ checkpoint-*/
8
+ cifar10-vit/
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Computer Vision Classification Model Comparison
3
- emoji: 📊
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
@@ -10,4 +10,58 @@ pinned: false
10
  short_description: 'Block 2 '
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Computer Vision Classification Model Comparison
3
+ emoji: "\U0001F4CA"
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
 
10
  short_description: 'Block 2 '
11
  ---
12
 
13
+ # CIFAR-10 Image Classification Model Comparison
14
+
15
+ Compare three classification approaches on CIFAR-10 images:
16
+
17
+ - **Fine-tuned ViT** ([adisaljusi/vit-base-cifar10](https://huggingface.co/adisaljusi/vit-base-cifar10)) — transfer learning model trained on CIFAR-10
18
+ - **CLIP Zero-Shot** (`openai/clip-vit-large-patch14`) — open-source zero-shot classification
19
+ - **OpenAI GPT-4.1-mini** — closed-source vision model via API
20
+
21
+ ## Dataset
22
+
23
+ **CIFAR-10** — 60,000 32x32 color images in 10 classes (6,000 images per class):
24
+
25
+ | Split | Images |
26
+ |-------|--------|
27
+ | Train | 50,000 |
28
+ | Test | 10,000 |
29
+
30
+ **Classes:** airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
31
+
32
+ Source: [Hugging Face `uoft-cs/cifar10`](https://huggingface.co/datasets/uoft-cs/cifar10)
33
+
34
+ ## Preprocessing
35
+
36
+ - Resize from 32x32 to 224x224 (ViT input size)
37
+ - Normalize pixel values with mean=0.5, std=0.5 per channel
38
+ - Convert all images to RGB
39
+
40
+ Applied using `AutoImageProcessor` from `google/vit-base-patch16-224`.
41
+
42
+ ## Model & Evaluation
43
+
44
+ **Base model:** [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224)
45
+
46
+ **Transfer learning approach:** All layers frozen except the final classification head (10 outputs for CIFAR-10 classes). Only 7,690 of 85.8M parameters are trainable.
47
+
48
+ **Training config:** 5 epochs, batch size 16, learning rate 3e-4, AdamW optimizer.
49
+
50
+ ### Training Results
51
+
52
+ | Epoch | Training Loss | Validation Loss | Accuracy |
53
+ |------:|--------------:|----------------:|---------:|
54
+ | _To be filled after training_ | | | |
55
+
56
+ ## Links
57
+
58
+ - **Model:** [adisaljusi/vit-base-cifar10](https://huggingface.co/adisaljusi/vit-base-cifar10)
59
+ - **App:** [adisaljusi/computer-vision-classification-model-comparison](https://huggingface.co/spaces/adisaljusi/computer-vision-classification-model-comparison)
60
+
61
+ ## Comparison Results
62
+
63
+ Results on example images comparing all three models:
64
+
65
+ | Image | True Class | ViT Top-1 (score) | CLIP Top-1 (score) | OpenAI (label, confidence) |
66
+ |-------|-----------|-------------------|-------------------|---------------------------|
67
+ | _To be filled after running the app_ | | | | |
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import os
4
+
5
+ import gradio as gr
6
+ from dotenv import load_dotenv
7
+ from openai import OpenAI
8
+ from transformers import pipeline
9
+
10
# Merge variables from a local .env file (if present) into the environment
# before any configuration is read.
load_dotenv()

# OpenAI configuration. The client is left as None when no API key is set;
# classify_with_openai() checks for that and reports it instead of calling out.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
openai_client = None if not OPENAI_API_KEY else OpenAI(api_key=OPENAI_API_KEY)

# Candidate labels shared by the CLIP zero-shot pipeline and the OpenAI prompt.
labels_cifar10 = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Hugging Face pipelines, created once at import time and reused per request.
vit_classifier = pipeline(
    task="image-classification",
    model="adisaljusi/vit-base-cifar10",
)
clip_detector = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)
27
+
28
+
29
def encode_image(image_path):
    """Return the contents of the file at *image_path* as base64 text.

    The file is opened in binary mode, its bytes are base64-encoded, and the
    encoded bytes are decoded to a ``str`` using UTF-8.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
32
+
33
+
34
def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if it has one."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line, which may carry a language tag such as "json".
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def classify_with_openai(image_path):
    """Classify the image at *image_path* with the configured OpenAI model.

    Args:
        image_path: Filesystem path of the image to send to the API.

    Returns:
        The parsed JSON reply from the model (requested keys: ``label``,
        ``confidence``, ``reasoning``). Instead of raising, the function
        returns a dict with an ``"error"`` key when no API key is configured
        or the API request fails, and a dict with ``"raw_response"`` and
        ``"warning"`` keys when the reply is not valid JSON.
    """
    if openai_client is None:
        return {
            "error": "Missing OPENAI_API_KEY. Add it to your environment or .env file to enable OpenAI classification."
        }

    # Function-scope imports keep this block self-contained; the file's
    # top-level import list is left untouched.
    import mimetypes

    from openai import OpenAIError

    prompt = (
        "Classify the object in this image. Choose the best matching label from this list: "
        f"{', '.join(labels_cifar10)}. "
        "Return valid JSON with exactly these keys: "
        "label, confidence, reasoning. "
        "The confidence must be a number between 0 and 1."
    )

    # Describe the payload with the file's real media type (uploads are not
    # always JPEG); fall back to JPEG when the extension is unknown.
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    base64_image = encode_image(image_path)

    try:
        response = openai_client.responses.create(
            model=OPENAI_MODEL,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": f"data:{mime_type};base64,{base64_image}",
                        },
                    ],
                }
            ],
        )
    except OpenAIError as exc:
        # Report the failure as data so the caller can still show the results
        # of the other models instead of failing the whole comparison.
        return {"error": f"OpenAI request failed: {exc}"}

    output_text = response.output_text
    try:
        # Models sometimes wrap JSON in a Markdown code fence; unwrap it first.
        return json.loads(_strip_code_fence(output_text))
    except json.JSONDecodeError:
        return {
            "raw_response": output_text,
            "warning": "OpenAI response was not valid JSON.",
        }
74
+
75
+
76
def classify_image(image):
    """Run one image through all three classifiers and collect the results.

    Args:
        image: Path of the image file to classify (the Gradio input is
            declared with ``type="filepath"``). May be ``None`` or empty when
            the form is submitted without an image.

    Returns:
        A dict with one entry per model: label-to-score mappings for the ViT
        and CLIP pipelines, and whatever ``classify_with_openai`` returns for
        the OpenAI model. When no image is given, a dict with a single
        ``"error"`` key is returned instead.
    """
    # Guard first: passing a missing image to the pipelines only produces an
    # opaque exception, so report the problem in the JSON output instead.
    if not image:
        return {"error": "No image provided. Upload an image or select one of the examples."}

    vit_output = {
        prediction["label"]: prediction["score"]
        for prediction in vit_classifier(image)
    }

    clip_output = {
        prediction["label"]: prediction["score"]
        for prediction in clip_detector(image, candidate_labels=labels_cifar10)
    }

    openai_output = classify_with_openai(image)

    return {
        "ViT Classification": vit_output,
        "CLIP Zero-Shot Classification": clip_output,
        "OpenAI Vision Classification": openai_output,
    }
90
+
91
+
92
# Bundled sample inputs; each inner list holds the value for the single
# image input of the interface.
example_images = [
    ["example_images/airplane.jpg"],
    ["example_images/automobile.jpg"],
    ["example_images/cat.jpg"],
    ["example_images/dog.jpg"],
    ["example_images/horse.jpg"],
    ["example_images/ship.jpg"],
]

# One image in (handed to classify_image as a file path), one JSON document
# out containing the results of all three models.
iface = gr.Interface(
    classify_image,
    gr.Image(type="filepath"),
    gr.JSON(),
    examples=example_images,
    title="CIFAR-10 Classification Comparison",
    description=(
        "Upload an image and compare classification results from three models: "
        "a fine-tuned ViT model, a zero-shot CLIP model, and OpenAI GPT-4.1-mini vision."
    ),
)

# Start the web app when the module is executed.
iface.launch()
example_images/airplane.jpg ADDED
example_images/automobile.jpg ADDED
example_images/cat.jpg ADDED
example_images/dog.jpg ADDED
example_images/horse.jpg ADDED
example_images/ship.jpg ADDED
requirements-dev.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ -r requirements.txt
2
+ ipykernel
3
+ datasets
4
+ evaluate
5
+ matplotlib
6
+ numpy
7
+ huggingface-hub
8
+ ipywidgets
requirements.txt ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-doc==0.0.4
2
+ # via typer
3
+ annotated-types==0.7.0
4
+ # via pydantic
5
+ anyio==4.13.0
6
+ # via
7
+ # httpx
8
+ # openai
9
+ certifi==2026.2.25
10
+ # via
11
+ # httpcore
12
+ # httpx
13
+ click==8.3.2
14
+ # via typer
15
+ distro==1.9.0
16
+ # via openai
17
+ filelock==3.25.2
18
+ # via
19
+ # huggingface-hub
20
+ # torch
21
+ fsspec==2026.3.0
22
+ # via
23
+ # huggingface-hub
24
+ # torch
25
+ h11==0.16.0
26
+ # via httpcore
27
+ hf-xet==1.4.3
28
+ # via huggingface-hub
29
+ httpcore==1.0.9
30
+ # via httpx
31
+ httpx==0.28.1
32
+ # via
33
+ # huggingface-hub
34
+ # openai
35
+ huggingface-hub==1.9.0
36
+ # via
37
+ # tokenizers
38
+ # transformers
39
+ idna==3.11
40
+ # via
41
+ # anyio
42
+ # httpx
43
+ jinja2==3.1.6
44
+ # via torch
45
+ jiter==0.13.0
46
+ # via openai
47
+ markdown-it-py==4.0.0
48
+ # via rich
49
+ markupsafe==3.0.3
50
+ # via jinja2
51
+ mdurl==0.1.2
52
+ # via markdown-it-py
53
+ mpmath==1.3.0
54
+ # via sympy
55
+ networkx==3.6.1
56
+ # via torch
57
+ numpy==2.4.4
58
+ # via transformers
59
+ openai==2.30.0
60
+ # via -r requirements.txt
61
+ packaging==26.0
62
+ # via
63
+ # huggingface-hub
64
+ # transformers
65
+ pydantic==2.12.5
66
+ # via openai
67
+ pydantic-core==2.41.5
68
+ # via pydantic
69
+ pygments==2.20.0
70
+ # via rich
71
+ python-dotenv==1.2.2
72
+ # via -r requirements.txt
73
+ pyyaml==6.0.3
74
+ # via
75
+ # huggingface-hub
76
+ # transformers
77
+ regex==2026.4.4
78
+ # via transformers
79
+ rich==14.3.3
80
+ # via typer
81
+ safetensors==0.7.0
82
+ # via transformers
83
+ setuptools==81.0.0
84
+ # via torch
85
+ shellingham==1.5.4
86
+ # via typer
87
+ sniffio==1.3.1
88
+ # via openai
89
+ sympy==1.14.0
90
+ # via torch
91
+ tokenizers==0.22.2
92
+ # via transformers
93
+ torch==2.11.0
94
+ # via -r requirements.txt
95
+ tqdm==4.67.3
96
+ # via
97
+ # huggingface-hub
98
+ # openai
99
+ # transformers
100
+ transformers==5.5.0
101
+ # via -r requirements.txt
102
+ typer==0.24.1
103
+ # via
104
+ # huggingface-hub
105
+ # transformers
106
+ typing-extensions==4.15.0
107
+ # via
108
+ # huggingface-hub
109
+ # openai
110
+ # pydantic
111
+ # pydantic-core
112
+ # torch
113
+ # typing-inspection
114
+ typing-inspection==0.4.2
115
+ # via pydantic
train.ipynb ADDED
The diff for this file is too large to render. See raw diff