File size: 2,266 Bytes
c4ed1c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from PIL import Image
from io import BytesIO
from google import genai
from google.genai import types
from dotenv import load_dotenv
from transforms import transform, target_transform
from torchvision import models
from hf_class import MyResNet18
import torch
import torch.nn.functional as F
import os
import gradio as gr

# Load settings from a .env file into the process environment (python-dotenv).
# predict_image later reads GOOGLE_API_KEY from the environment via os.getenv.
load_dotenv()

# Maps the classifier's output index to a human-readable waste category.
# The position of each name in the tuple is its class index (0-9).
index_to_label = dict(enumerate((
    'battery',
    'biological',
    'cardboard',
    'clothes',
    'glass',
    'metal',
    'paper',
    'plastic',
    'shoes',
    'trash',
)))

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Fetch the pretrained weights, move the network onto the chosen device and
# switch it to evaluation mode before it is used for prediction.
model = MyResNet18.from_pretrained('harriskr14/WasteNET').to(device)
model.eval()

def predict_image(img: Image.Image) -> tuple[str, float, str]:
    """Classify a waste image and ask Gemini how to manage that waste.

    Args:
        img: The image to classify, as a PIL image (the Gradio input is
            configured with ``type="pil"``).

    Returns:
        A ``(label, confidence, suggestions)`` tuple: the predicted class
        name from ``index_to_label``, the softmax probability of that class
        rounded to four decimals, and the text returned by Gemini (or a
        short placeholder if the response carried no text).

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        # Gradio hands over None when the form is submitted without an image;
        # fail with a readable message instead of an error inside transform.
        raise gr.Error("Please upload an image first.")

    img_tensor = transform['test'](img).unsqueeze(0).to(device)

    # Inference only: disable autograd so no graph is built per request.
    with torch.no_grad():
        output = model(img_tensor)
        softmax = F.softmax(output, dim=1)
        _, preds = torch.max(output, dim=1)

    class_index = preds.item()
    predicted_label = index_to_label[class_index]
    confidence_score = float('{:.4f}'.format(softmax[0][class_index].item()))

    client = genai.Client(api_key=os.getenv('GOOGLE_API_KEY'))

    if predicted_label != 'trash':
        # A specific category is enough context for a text-only prompt.
        contents = f"Mention 5 ways of how to manage {predicted_label} waste concisely!"
    else:
        # 'trash' is a catch-all class, so send the picture itself and let
        # Gemini work out what kind of waste it is looking at.
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        contents = [
            "Mention 5 ways of how to manage this waste concisely!",
            types.Part.from_bytes(data=buffer.getvalue(), mime_type="image/png"),
        ]

    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=contents,
    )

    # response.text may be None when the reply contains no text part.
    suggestions = response.text or "No suggestions were returned."

    return predicted_label, confidence_score, suggestions

# Gradio UI: one PIL image in, three outputs out. The output widgets are
# listed in the same order as the values returned by predict_image
# (label, confidence score, suggestion text).
demo = gr.Interface(
    title="Waste Classification and Management Suggestion",
    description="Upload an image of garbage to classify its type and get suggestions on how to manage it.",
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(label="Predicted Label", num_top_classes=1),
        gr.Number(label="Confidence Score"),
        gr.Markdown(
            label="Waste Management Suggestions",
            container=True,
            show_label=True,
        ),
    ],
)

# Start the web server and request a public share link.
demo.launch(share=True)