File size: 2,112 Bytes
f1a734d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import torch

# *******************************************************************
# ЕҢ ТАЗА ЖӘНЕ ҚУАТТЫ МОДЕЛЬ ID-І (90%+ Accuracy, таза PyTorch)
# *******************************************************************
MODEL_ID = "keremberke/vit-base-patch16-224-full-empty-trash-bin" 
CLASS_NAMES = ['Empty', 'Full'] 

try:
    feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)
    model = ViTForImageClassification.from_pretrained(MODEL_ID)
    MODEL_LOADED = True
    # Модельді 2 классқа бейімдейміз (қате болмауы үшін)
    if model.config.id2label:
        CLASS_NAMES = [model.config.id2label[i] for i in model.config.id2label]

except Exception as e:
    print(f"ERROR: Model loading failed: {e}")
    MODEL_LOADED = False

def _uniform_scores(labels):
    """Return an equal score for every label (an empty dict when there are none)."""
    if not labels:
        return {}
    share = 1.0 / len(labels)
    return {label: share for label in labels}


def classify_trash_bin(image):
    """Classify a trash-bin photo with the loaded ViT model.

    Args:
        image: The picture delivered by the Gradio image input
            (``type="numpy"``), or ``None`` when nothing has been provided.

    Returns:
        A dict mapping every label in ``CLASS_NAMES`` to its softmax
        probability, suitable for ``gr.Label``. When no image is given, every
        label gets an equal placeholder score. If the model failed to load,
        or if inference raises, ``{"Error": 1.0, "Check Logs": 0.0}`` is
        returned instead (inference errors are also printed).
    """
    if not MODEL_LOADED:
        return {"Error": 1.0, "Check Logs": 0.0}

    if image is None:
        # Nothing to classify yet: show an uninformative, evenly split result
        # over the labels the model actually has.
        return _uniform_scores(CLASS_NAMES)

    try:
        img = Image.fromarray(image).convert("RGB")
        inputs = feature_extractor(images=img, return_tensors="pt")

        with torch.no_grad():
            logits = model(**inputs).logits

        # Select the single image in the batch explicitly; squeeze() would
        # collapse to a 0-d tensor (and .tolist() to a bare float) for a
        # one-label model.
        probabilities = torch.softmax(logits[0], dim=-1).tolist()
    except Exception as e:
        print(f"ERROR: Inference failed: {e}")
        return {"Error": 1.0, "Check Logs": 0.0}

    # Report every label; gr.Label(num_top_classes=...) decides how many are
    # displayed, so truncating here would only desynchronize the two lists.
    return {name: float(p) for name, p in zip(CLASS_NAMES, probabilities)}

# Build the Gradio UI: a NumPy image input feeding classify_trash_bin,
# with a label output that shows the top two classes.
_input_component = gr.Image(type="numpy", label="SmartTrachAI Input")
_output_component = gr.Label(num_top_classes=2, label="Prediction")

iface = gr.Interface(
    fn=classify_trash_bin,
    inputs=_input_component,
    outputs=_output_component,
    title="SmartTrachAI",
    description="Automated Trash Bin Status Detector.",
)

# Start the Gradio app.
iface.launch()