# SmartCity / app.py
# Author: Nonabzbssbbsbs
# Commit f1a734d (verified): "Create app.py"
import gradio as gr
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import torch
# *******************************************************************
# Hugging Face Hub checkpoint used for classification. The original
# author describes it as a clean, pure-PyTorch model with 90%+ accuracy
# (claim not verified here).
# *******************************************************************
MODEL_ID = "keremberke/vit-base-patch16-224-full-empty-trash-bin"
# Fallback label names, kept only if the checkpoint config has no id2label.
CLASS_NAMES = ['Empty', 'Full']
# Load the model once at import time. On any failure the app still starts,
# but MODEL_LOADED is False and classify_trash_bin() returns an error label.
try:
    # NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
    # releases in favour of ViTImageProcessor / AutoImageProcessor -- confirm
    # against the transformers version this Space pins.
    feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)
    model = ViTForImageClassification.from_pretrained(MODEL_ID)
    MODEL_LOADED = True
    # Replace the fallback names with the checkpoint's own labels, ordered by
    # class index so CLASS_NAMES[i] always names logit i. Iterating the dict
    # directly would rely on insertion order, which only happens to match
    # index order.
    if model.config.id2label:
        CLASS_NAMES = [
            model.config.id2label[i] for i in sorted(model.config.id2label)
        ]
except Exception as e:
    # Deliberately broad: any load failure (e.g. download error, missing
    # repo, incompatible weights) degrades to an error label instead of
    # crashing the app.
    print(f"ERROR: Model loading failed: {e}")
    MODEL_LOADED = False
def classify_trash_bin(image):
    """Classify a trash-bin photo and return per-class confidences.

    Args:
        image: The value delivered by the ``gr.Image(type="numpy")`` input,
            or ``None`` when no image has been provided. Presumably an
            HxWxC uint8 array -- TODO confirm against the Gradio version.

    Returns:
        A ``{label: probability}`` dict for ``gr.Label``, with labels taken
        from ``CLASS_NAMES`` and at most the first two classes reported.
        With no image, each of the first two labels gets 0.5. If the model
        failed to load, or inference raises, the sentinel
        ``{"Error": 1.0, "Check Logs": 0.0}`` is returned instead.
    """
    if not MODEL_LOADED:
        return {"Error": 1.0, "Check Logs": 0.0}
    if image is None:
        # Neutral placeholder. Slicing (rather than indexing [0] and [1])
        # avoids an IndexError if fewer than two labels are configured.
        return {name: 0.5 for name in CLASS_NAMES[:2]}
    try:
        img = Image.fromarray(image).convert("RGB")
        inputs = feature_extractor(images=img, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        # Take row 0 of the single-image batch instead of squeeze(). squeeze()
        # turns a one-logit row into a bare float, and a float has no len().
        probabilities = torch.softmax(logits, dim=1)[0].tolist()
        # Report only the first two classes. zip() stops at the shorter
        # sequence, so a label list longer than the truncated probabilities
        # can no longer raise IndexError.
        return {
            name: float(prob)
            for name, prob in zip(CLASS_NAMES, probabilities[:2])
        }
    except Exception as e:
        # The UI tells the user to "Check Logs", so actually log the cause
        # instead of swallowing it silently.
        print(f"ERROR: Inference failed: {e}")
        return {"Error": 1.0, "Check Logs": 0.0}
# Build the Gradio UI: an image upload goes in, and the class confidences
# from classify_trash_bin come out (top 2 shown).
iface = gr.Interface(
    fn=classify_trash_bin,
    inputs=gr.Image(type="numpy", label="SmartTrachAI Input"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="SmartTrachAI",
    description="Automated Trash Bin Status Detector."
)

# Start the server only when this file is executed as a script (presumably
# how the hosting Space runs app.py -- confirm). Importing the module for
# tests or reuse then has no web-server side effect.
if __name__ == "__main__":
    iface.launch()