| import streamlit as st |
| from transformers import pipeline |
| from PIL import Image |
|
|
# Named model ids. Neither is referenced by the functions visible in this file.
MODEL_1 = "google/vit-base-patch16-224"
MODEL_2 = "nateraw/vit-age-classifier"

# print_result() shows only labels scoring at least this much; lower scores
# are summed into a single "discarded" figure.
MIN_ACEPTABLE_SCORE = 0.1

# Not referenced by the functions visible in this file.
MAX_N_LABELS = 5

# Menu content grouped as (section header, model ids).
#   * Headers start with "--" and are rejected as selections by classify()/main().
#   * Entries starting with "#" are rejected the same way (disabled models).
#   * A trailing " << new >>" badge is stripped by classify() before loading.
_MODEL_GROUPS = (
    (
        "-- General Image Classification --",
        [
            "google/vit-base-patch16-224",
            "microsoft/resnet-50",
            "microsoft/resnet-18",
            "microsoft/resnet-34",
            "microsoft/resnet-101",
            "microsoft/resnet-152",
            "microsoft/swin-tiny-patch4-window7-224",
            "microsoft/swinv2-base-patch4-window16-256",
            "microsoft/beit-base-patch16-224-pt22k-ft22k",
            "facebook/convnext-large-224",
            "facebook/convnext-base-224-22k-1k",
            "facebook/convnext-tiny-224",
            "nvidia/mit-b0",
            "timm/resnet50.a1_in1k",
            "timm/tf_efficientnetv2_s.in21k",
            "timm/convnext_tiny.fb_in22k",
            # NOTE(review): unlike its neighbours this id has no "org/" prefix;
            # confirm it resolves on the Hub.
            "vit-base-patch16-224-in21k",
            "facebook/deit-base-distilled-patch16-224 << new >>",
            "WinKawaks/vit-tiny-patch16-224 << new >>",
        ],
    ),
    (
        "-- Age Classification --",
        [
            "nateraw/vit-age-classifier",
        ],
    ),
    (
        "-- NSFW Detection --",
        [
            "Falconsai/nsfw_image_detection",
            "LukeJacob2023/nsfw-image-detector",
            "carbon225/vit-base-patch16-224-hentai",
            "Marqo/nsfw-image-detection-384 << new >>",
        ],
    ),
    (
        "-- Aesthetic/Art Classification --",
        [
            "cafeai/cafe_aesthetic",
            "shadowlilac/aesthetic-shadow",
            "pixai-labs/pixai-tagger-v0.9 << new >>",
        ],
    ),
    (
        "-- Face/Emotion Classification --",
        [
            "trpakov/vit-face-expression",
            "RickyIG/emotion_face_image_classification",
            "rizvandwiki/gender-classification",
        ],
    ),
    (
        "-- Food Classification --",
        [
            "nateraw/food",
            "BinhQuocNguyen/food-recognition-model << new >>",
        ],
    ),
    (
        "-- AI vs Human Detection --",
        [
            "Ateeqq/ai-vs-human-image-detector << new >>",
            "umm-maybe/AI-image-detector << new >>",
        ],
    ),
    (
        "-- Deepfake Detection --",
        [
            "not-lain/deepfake",
        ],
    ),
    # NOTE(review): this section currently lists no models.
    ("-- Anime/Manga Classification --", []),
    (
        "-- Human Activity Recognition --",
        [
            "DunnBC22/vit-base-patch16-224-in21k_Human_Activity_Recognition",
        ],
    ),
    (
        "-- Clothing/Fashion --",
        [
            "aalonso-developer/vit-base-patch16-224-in21k-clothing-classifier",
        ],
    ),
    (
        "-- Real Estate --",
        [
            "andupets/real-estate-image-classification",
        ],
    ),
    (
        "-- Satellite/Remote Sensing --",
        [
            "FatihC/swin-tiny-patch4-window7-224-finetuned-eurosat-watermark",
        ],
    ),
    (
        "-- Car Classification --",
        [
            "lamnt2008/car_brands_classification << new >>",
        ],
    ),
    (
        "-- Document Classification --",
        [
            "docling-project/DocumentFigureClassifier-v2.5 << new >>",
        ],
    ),
    (
        "-- EfficientNet (timm) --",
        [
            "timm/efficientnet_b0.ra_in1k << new >>",
            "timm/mobilenetv3_large_100.ra_in1k",
            "timm/mobilenetv3_small_100.lamb_in1k << new >>",
        ],
    ),
    (
        "-- Experimental/Future --",
        [
            "#q-future/one-align",
        ],
    ),
)

# Flat list handed to the select box: each header followed by its model ids.
MODELS = [
    entry
    for header, model_ids in _MODEL_GROUPS
    for entry in (header, *model_ids)
]
|
|
def classify(image, model):
    """Classify ``image`` with the Hugging Face model selected in the menu.

    Args:
        image: The image passed straight through to the pipeline.
        model: A menu entry from ``MODELS``. It may carry a trailing
            `` << new >>`` badge, which is removed before loading.

    Returns:
        Whatever the ``image-classification`` pipeline returns for ``image``
        (print_result() reads it as a sequence of dicts with ``label`` and
        ``score`` keys), or an empty list when ``model`` is a section header
        (``--`` prefix) or a disabled entry (``#`` prefix).
    """
    # Headers and disabled entries are menu decoration, not loadable models.
    if model.startswith(("--", "#")):
        st.warning("Please select a valid model from the list")
        return []

    # The badge is display-only; the remaining text is the model id.
    checkpoint = model.replace(" << new >>", "")
    return pipeline("image-classification", model=checkpoint)(image)
|
|
def save_result(result):
    """Placeholder for persisting a classification result.

    For now this only writes a notice to the page; ``result`` is not used.
    """
    notice = "In the future, this function will save the result in a database."
    st.write(notice)
|
|
def print_result(result):
    """Render classification results on the page.

    Each entry of ``result`` is read through its ``label`` and ``score``
    keys. Entries scoring at least ``MIN_ACEPTABLE_SCORE`` are shown as
    label, progress bar and raw score; the scores of all other entries are
    added together and shown once at the end.
    """
    discarded_total = 0
    for prediction in result:
        score = prediction['score']
        if score < MIN_ACEPTABLE_SCORE:
            # Too low to display individually; fold it into the summary.
            discarded_total += score
            continue
        st.write(prediction['label'])
        st.progress(score)
        st.write(score)

    # Summary of everything that fell below the display threshold.
    st.write("comulative_discarded_score:")
    st.progress(discarded_total)
    st.write(discarded_total)
| |
|
|
|
|
def main():
    """Render the page: intro text, donate badge, upload widget, model picker.

    Once an image is uploaded it is previewed; pressing "Classify" runs the
    selected model through classify() and hands the result to print_result()
    and save_result(). Section headers and disabled menu entries are rejected
    with a warning, and an upload that cannot be opened as an image is
    reported with an error message instead of an exception.
    """
    st.title("Image Classification")
    st.write("This is a simple web app to test and compare different image classifier models using Hugging Face's image-classification pipeline.")
    st.markdown(":white_check_mark: **:green[22 new models added!]** - Including Medical, AI vs Human detection, Anime classification and more.")
    st.write("From time to time more models will be added to the list. If you want to add a model, please open an issue on the GitHub repository.")
    st.write("If you like this project, please consider liking it or buying me a coffee. It will help me to keep working on this and other projects. Thank you!")

    # "Buy me a coffee" badge: an image wrapped in a link. The link needs
    # visible content to be clickable, and raw HTML (hence
    # unsafe_allow_html=True) lets the image be constrained to image_size.
    bmc_link = "https://www.buymeacoffee.com/nuno.tome"
    image_url = "https://i.giphy.com/RETzc1mj7HpZPuNf3e.webp"
    image_size = "150px"
    image_link_markdown = (
        f'<a href="{bmc_link}" target="_blank" rel="noopener noreferrer">'
        f'<img src="{image_url}" alt="Buy Me a Coffee" '
        f'style="width: {image_size};">'
        "</a>"
    )
    st.markdown(image_link_markdown, unsafe_allow_html=True)

    input_image = st.file_uploader("Upload Image")
    shosen_model = st.selectbox("Select the model to use", MODELS)

    if input_image is None:
        return

    # The uploader accepts any file type, so the upload may not be an image.
    # Pillow signals that with an OSError subclass.
    try:
        preview_image = Image.open(input_image)
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        return
    st.image(preview_image, caption="Uploaded Image")

    if not st.button("Classify"):
        return

    # Headers ("--") and disabled entries ("#") are not real models.
    if shosen_model.startswith(("--", "#")):
        st.warning("Please select a valid model from the list")
        return

    # Open a fresh image object for the classifier rather than reusing the
    # one already handed to st.image (kept from the original flow).
    image_to_classify = Image.open(input_image)
    classification_result = classify(image_to_classify, shosen_model)
    print_result(classification_result)
    save_result(classification_result)
|
|
|
|
# Build the page only when this file is executed as the main module.
if __name__ == "__main__":
    main()