Spaces:
Configuration error
Configuration error
| import gradio as gr | |
| from transformers import ViTFeatureExtractor, ViTForImageClassification | |
| from PIL import Image | |
| import torch | |
| # ******************************************************************* | |
| # ЕҢ ТАЗА ЖӘНЕ ҚУАТТЫ МОДЕЛЬ ID-І (90%+ Accuracy, таза PyTorch) | |
| # ******************************************************************* | |
| MODEL_ID = "keremberke/vit-base-patch16-224-full-empty-trash-bin" | |
| CLASS_NAMES = ['Empty', 'Full'] | |
| try: | |
| feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID) | |
| model = ViTForImageClassification.from_pretrained(MODEL_ID) | |
| MODEL_LOADED = True | |
| # Модельді 2 классқа бейімдейміз (қате болмауы үшін) | |
| if model.config.id2label: | |
| CLASS_NAMES = [model.config.id2label[i] for i in model.config.id2label] | |
| except Exception as e: | |
| print(f"ERROR: Model loading failed: {e}") | |
| MODEL_LOADED = False | |
def classify_trash_bin(image):
    """Classify a trash-bin photo and return class probabilities for a ``gr.Label``.

    Args:
        image: The value produced by the ``gr.Image(type="numpy")`` input, or
            None when no image has been supplied.

    Returns:
        A dict mapping class name to probability, limited to the first two
        classes. Returns a 0.5/0.5 placeholder when ``image`` is None, and
        ``{"Error": 1.0, "Check Logs": 0.0}`` when the model is not loaded or
        inference raises.
    """
    if not MODEL_LOADED:
        return {"Error": 1.0, "Check Logs": 0.0}
    if image is None:
        return {CLASS_NAMES[0]: 0.5, CLASS_NAMES[1]: 0.5}
    try:
        img = Image.fromarray(image).convert("RGB")
        inputs = feature_extractor(images=img, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Take the single batch row rather than squeeze(): squeeze() collapses a
        # one-class output to a 0-d tensor whose tolist() is a float, not a list.
        probabilities = torch.softmax(outputs.logits, dim=-1)[0].tolist()
        # Report only the first two classes. zip() stops at the shorter sequence,
        # so a label list longer than two entries can no longer raise IndexError.
        return {
            name: float(prob)
            for name, prob in zip(CLASS_NAMES, probabilities[:2])
        }
    except Exception as e:
        # The UI tells the user to "Check Logs", so actually record the cause there.
        print(f"ERROR: Inference failed: {e}")
        return {"Error": 1.0, "Check Logs": 0.0}
# Build the Gradio UI: one image input wired to the classifier, whose result is
# shown in a label component (top two classes).
# NOTE(review): "SmartTrachAI" may be a typo for "SmartTrashAI" — confirm before renaming.
_image_input = gr.Image(type="numpy", label="SmartTrachAI Input")
_label_output = gr.Label(num_top_classes=2, label="Prediction")

iface = gr.Interface(
    fn=classify_trash_bin,
    inputs=_image_input,
    outputs=_label_output,
    title="SmartTrachAI",
    description="Automated Trash Bin Status Detector.",
)

# Start the web server when the module runs.
iface.launch()