Nuno-Tome commited on
Commit
c5b2d17
·
1 Parent(s): 98c6d70

feat: run all models and display results in table instead of dropdown

Browse files
Files changed (1) hide show
  1. app.py +137 -141
app.py CHANGED
@@ -1,164 +1,160 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
 
4
 
5
- MODEL_1 = "google/vit-base-patch16-224"
6
  MIN_ACEPTABLE_SCORE = 0.1
7
  MAX_N_LABELS = 5
8
- MODEL_2 = "nateraw/vit-age-classifier"
9
  MODELS = [
10
- "-- General Image Classification --",
11
- "google/vit-base-patch16-224",
12
- "microsoft/resnet-50",
13
- "microsoft/resnet-18",
14
- "microsoft/resnet-34",
15
- "microsoft/resnet-101",
16
- "microsoft/resnet-152",
17
- "microsoft/swin-tiny-patch4-window7-224",
18
- "microsoft/swinv2-base-patch4-window16-256",
19
- "microsoft/beit-base-patch16-224-pt22k-ft22k",
20
- "facebook/convnext-large-224",
21
- "facebook/convnext-base-224-22k-1k",
22
- "facebook/convnext-tiny-224",
23
- "nvidia/mit-b0",
24
- "timm/resnet50.a1_in1k",
25
- "timm/tf_efficientnetv2_s.in21k",
26
- "timm/convnext_tiny.fb_in22k",
27
- "vit-base-patch16-224-in21k",
28
- "facebook/deit-base-distilled-patch16-224 << new >>",
29
- "WinKawaks/vit-tiny-patch16-224 << new >>",
30
-
31
- "-- Age Classification --",
32
- "nateraw/vit-age-classifier",
33
-
34
- "-- NSFW Detection --",
35
- "Falconsai/nsfw_image_detection",
36
- "LukeJacob2023/nsfw-image-detector",
37
- "carbon225/vit-base-patch16-224-hentai",
38
- "Marqo/nsfw-image-detection-384 << new >>",
39
-
40
- "-- Aesthetic/Art Classification --",
41
- "cafeai/cafe_aesthetic",
42
- "shadowlilac/aesthetic-shadow",
43
- "pixai-labs/pixai-tagger-v0.9 << new >>",
44
-
45
- "-- Face/Emotion Classification --",
46
- "trpakov/vit-face-expression",
47
- "RickyIG/emotion_face_image_classification",
48
- "rizvandwiki/gender-classification",
49
-
50
- "-- Food Classification --",
51
- "nateraw/food",
52
- "BinhQuocNguyen/food-recognition-model << new >>",
53
-
54
- "-- Medical/Dermatology --",
55
- "google/derm-foundation << new >>",
56
- "google/cxr-foundation << new >>",
57
- "Anwarkh1/Skin_Cancer-Image_Classification << new >>",
58
-
59
- "-- AI vs Human Detection --",
60
- "Ateeqq/ai-vs-human-image-detector << new >>",
61
- "umm-maybe/AI-image-detector << new >>",
62
-
63
- "-- Deepfake Detection --",
64
- "not-lain/deepfake",
65
-
66
- "-- Anime/Manga Classification --",
67
- #"Readidno/anime.mili << new >>", # Not working - missing model_type
68
-
69
- "-- Human Activity Recognition --",
70
- "DunnBC22/vit-base-patch16-224-in21k_Human_Activity_Recognition",
71
-
72
- "-- Clothing/Fashion --",
73
- "aalonso-developer/vit-base-patch16-224-in21k-clothing-classifier",
74
-
75
- "-- Real Estate --",
76
- "andupets/real-estate-image-classification",
77
-
78
- "-- Satellite/Remote Sensing --",
79
- "FatihC/swin-tiny-patch4-window7-224-finetuned-eurosat-watermark",
80
-
81
- "-- Car Classification --",
82
- "lamnt2008/car_brands_classification << new >>",
83
-
84
- "-- Document Classification --",
85
- "docling-project/DocumentFigureClassifier-v2.5 << new >>",
86
-
87
- "-- EfficientNet (timm) --",
88
- "timm/efficientnet_b0.ra_in1k << new >>",
89
- "timm/mobilenetv3_large_100.ra_in1k",
90
- "timm/mobilenetv3_small_100.lamb_in1k << new >>",
91
-
92
- "-- Experimental/Future --",
93
- "#q-future/one-align",
94
- ]
95
 
96
- def classify(image, model):
97
- model_name = model.replace(" << new >>", "")
98
- classifier = pipeline("image-classification", model=model_name)
99
- result= classifier(image)
100
- return result
101
 
102
- def save_result(result):
103
- st.write("In the future, this function will save the result in a database.")
 
 
104
 
105
- def print_result(result):
 
 
106
 
107
- comulative_discarded_score = 0
108
- for i in range(len(result)):
109
- if result[i]['score'] < MIN_ACEPTABLE_SCORE:
110
- comulative_discarded_score += result[i]['score']
111
- else:
112
- st.write(result[i]['label'])
113
- st.progress(result[i]['score'])
114
- st.write(result[i]['score'])
115
 
116
- st.write(f"comulative_discarded_score:")
117
- st.progress(comulative_discarded_score)
118
- st.write(comulative_discarded_score)
119
-
120
 
 
 
 
121
 
122
- def main():
123
- st.title("Image Classification")
124
- st.write("This is a simple web app to test and compare different image classifier models using Hugging Face's image-classification pipeline.")
125
- st.markdown(":white_check_mark: **:green[22 new models added!]** - Including Medical, AI vs Human detection, Anime classification and more.")
126
- st.write("From time to time more models will be added to the list. If you want to add a model, please open an issue on the GitHub repository.")
127
- st.write("If you like this project, please consider liking it or buying me a coffee. It will help me to keep working on this and other projects. Thank you!")
128
 
129
- # Buy me a Coffee Setup
130
- bmc_link = "https://www.buymeacoffee.com/nuno.tome"
131
- # image_url = "https://helloimjessa.files.wordpress.com/2021/06/bmc-button.png?w=150" # Image URL
132
- image_url = "https://i.giphy.com/RETzc1mj7HpZPuNf3e.webp" # Image URL
133
-
134
- image_size = "150px" # Image size
135
- #image_link_markdown = f"<img src='{image_url}' width='25%'>"
136
- image_link_markdown = f"[![Buy Me a Coffee]({image_url})]({bmc_link})"
137
 
138
- #image_link_markdown = f"[![Buy Me a Coffee]({image_url})]({bmc_link})" # Create a clickable image link
 
 
 
 
 
 
 
 
 
139
 
140
- st.markdown(image_link_markdown, unsafe_allow_html=True) # Display the image link
141
- # Buy me a Coffee Setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- #st.markdown("<img src='https://helloimjessa.files.wordpress.com/2021/06/bmc-button.png?w=1024' width='15%'>", unsafe_allow_html=True)
 
 
 
144
 
145
  input_image = st.file_uploader("Upload Image")
146
- shosen_model = st.selectbox("Select the model to use", MODELS)
147
-
148
-
149
  if input_image is not None:
150
  image_to_classify = Image.open(input_image)
151
- st.image(image_to_classify, caption="Uploaded Image")
152
- if st.button("Classify"):
153
- image_to_classify = Image.open(input_image)
154
- classification_obj1 =[]
155
- #avable_models = st.selectbox
 
156
 
157
- classification_result = classify(image_to_classify, shosen_model)
158
- classification_obj1.append(classification_result)
159
- print_result(classification_result)
160
- save_result(classification_result)
161
-
162
-
163
- if __name__ == "__main__":
164
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
+ import pandas as pd
5
 
 
6
  MIN_ACEPTABLE_SCORE = 0.1
7
  MAX_N_LABELS = 5
8
+
9
  MODELS = [
10
+ ("google/vit-base-patch16-224", "General Image Classification"),
11
+ ("microsoft/resnet-50", "General Image Classification"),
12
+ ("microsoft/resnet-18", "General Image Classification"),
13
+ ("microsoft/resnet-34", "General Image Classification"),
14
+ ("microsoft/resnet-101", "General Image Classification"),
15
+ ("microsoft/resnet-152", "General Image Classification"),
16
+ ("microsoft/swin-tiny-patch4-window7-224", "General Image Classification"),
17
+ ("microsoft/swinv2-base-patch4-window16-256", "General Image Classification"),
18
+ ("microsoft/beit-base-patch16-224-pt22k-ft22k", "General Image Classification"),
19
+ ("facebook/convnext-large-224", "General Image Classification"),
20
+ ("facebook/convnext-base-224-22k-1k", "General Image Classification"),
21
+ ("facebook/convnext-tiny-224", "General Image Classification"),
22
+ ("nvidia/mit-b0", "General Image Classification"),
23
+ ("timm/resnet50.a1_in1k", "General Image Classification"),
24
+ ("timm/tf_efficientnetv2_s.in21k", "General Image Classification"),
25
+ ("timm/convnext_tiny.fb_in22k", "General Image Classification"),
26
+ ("google/vit-base-patch16-224-in21k", "General Image Classification"),
27
+ ("facebook/deit-base-distilled-patch16-224", "General Image Classification"),
28
+ ("WinKawaks/vit-tiny-patch16-224", "General Image Classification"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ ("nateraw/vit-age-classifier", "Age Classification"),
 
 
 
 
31
 
32
+ ("Falconsai/nsfw_image_detection", "NSFW Detection"),
33
+ ("LukeJacob2023/nsfw-image-detector", "NSFW Detection"),
34
+ ("carbon225/vit-base-patch16-224-hentai", "NSFW Detection"),
35
+ ("Marqo/nsfw-image-detection-384", "NSFW Detection"),
36
 
37
+ ("cafeai/cafe_aesthetic", "Aesthetic/Art Classification"),
38
+ ("shadowlilac/aesthetic-shadow", "Aesthetic/Art Classification"),
39
+ ("pixai-labs/pixai-tagger-v0.9", "Aesthetic/Art Classification"),
40
 
41
+ ("trpakov/vit-face-expression", "Face/Emotion Classification"),
42
+ ("RickyIG/emotion_face_image_classification", "Face/Emotion Classification"),
43
+ ("rizvandwiki/gender-classification", "Face/Emotion Classification"),
 
 
 
 
 
44
 
45
+ ("nateraw/food", "Food Classification"),
46
+ ("BinhQuocNguyen/food-recognition-model", "Food Classification"),
 
 
47
 
48
+ ("google/derm-foundation", "Medical/Dermatology"),
49
+ ("google/cxr-foundation", "Medical/Dermatology"),
50
+ ("Anwarkh1/Skin_Cancer-Image_Classification", "Medical/Dermatology"),
51
 
52
+ ("Ateeqq/ai-vs-human-image-detector", "AI vs Human Detection"),
53
+ ("umm-maybe/AI-image-detector", "AI vs Human Detection"),
 
 
 
 
54
 
55
+ ("not-lain/deepfake", "Deepfake Detection"),
56
+
57
+ ("DunnBC22/vit-base-patch16-224-in21k_Human_Activity_Recognition", "Human Activity Recognition"),
58
+
59
+ ("aalonso-developer/vit-base-patch16-224-in21k-clothing-classifier", "Clothing/Fashion"),
60
+
61
+ ("andupets/real-estate-image-classification", "Real Estate"),
 
62
 
63
+ ("FatihC/swin-tiny-patch4-window7-224-finetuned-eurosat-watermark", "Satellite/Remote Sensing"),
64
+
65
+ ("lamnt2008/car_brands_classification", "Car Classification"),
66
+
67
+ ("docling-project/DocumentFigureClassifier-v2.5", "Document Classification"),
68
+
69
+ ("timm/efficientnet_b0.ra_in1k", "EfficientNet"),
70
+ ("timm/mobilenetv3_large_100.ra_in1k", "EfficientNet"),
71
+ ("timm/mobilenetv3_small_100.lamb_in1k", "EfficientNet"),
72
+ ]
73
 
74
+ def classify(image, model_name):
75
+ classifier = pipeline("image-classification", model=model_name)
76
+ result = classifier(image)
77
+ return result
78
+
79
+ def format_results(results):
80
+ labels = []
81
+ scores = []
82
+ for r in results[:MAX_N_LABELS]:
83
+ if r['score'] >= MIN_ACEPTABLE_SCORE:
84
+ labels.append(r['label'])
85
+ scores.append(f"{r['score']:.2%}")
86
+ return "<br>".join(labels), "<br>".join(scores)
87
+
88
+ def main():
89
+ st.title("Image Classification - Compare All Models")
90
+ st.write("This app runs ALL image classification models and displays results in a table.")
91
+ st.markdown(":white_check_mark: **:green[Run all models at once!]**")
92
 
93
+ bmc_link = "https://www.buymeacoffee.com/nuno.tome"
94
+ image_url = "https://i.giphy.com/RETzc1mj7HpZPuNf3e.webp"
95
+ image_link_markdown = f"[![Buy Me a Coffee]({image_url})]({bmc_link})"
96
+ st.markdown(image_link_markdown, unsafe_allow_html=True)
97
 
98
  input_image = st.file_uploader("Upload Image")
99
+
 
 
100
  if input_image is not None:
101
  image_to_classify = Image.open(input_image)
102
+ st.image(image_to_classify, caption="Uploaded Image", use_container_width=True)
103
+
104
+ if st.button("Run All Models", type="primary"):
105
+ results_data = []
106
+ progress_bar = st.progress(0)
107
+ status_text = st.empty()
108
 
109
+ for i, (model_name, category) in enumerate(MODELS):
110
+ status_text.text(f"Running model {i+1}/{len(MODELS)}: {model_name}")
111
+ try:
112
+ classification_result = classify(image_to_classify, model_name)
113
+ labels, scores = format_results(classification_result)
114
+ results_data.append({
115
+ "Model": model_name,
116
+ "Category": category,
117
+ "Top Labels": labels,
118
+ "Scores": scores
119
+ })
120
+ except Exception as e:
121
+ results_data.append({
122
+ "Model": model_name,
123
+ "Category": category,
124
+ "Top Labels": f"Error: {str(e)[:50]}",
125
+ "Scores": "-"
126
+ })
127
+ progress_bar.progress((i + 1) / len(MODELS))
128
+
129
+ status_text.text("Done!")
130
+
131
+ if results_data:
132
+ df = pd.DataFrame(results_data)
133
+ st.subheader(f"Results ({len(results_data)} models)")
134
+
135
+ st.markdown("""
136
+ <style>
137
+ .dataframe {font-size: 12px;}
138
+ </style>
139
+ """, unsafe_allow_html=True)
140
+
141
+ st.dataframe(
142
+ df,
143
+ use_container_width=True,
144
+ hide_index=True,
145
+ column_config={
146
+ "Model": st.column_config.TextColumn("Model", width="medium"),
147
+ "Category": st.column_config.TextColumn("Category", width="small"),
148
+ "Top Labels": st.column_config.TextColumn("Top Labels", width="large"),
149
+ "Scores": st.column_config.TextColumn("Scores", width="medium"),
150
+ }
151
+ )
152
+
153
+ csv = df.to_csv(index=False).encode('utf-8')
154
+ st.download_button(
155
+ "Download Results CSV",
156
+ csv,
157
+ "classification_results.csv",
158
+ "text/csv",
159
+ key='download-csv'
160
+ )