# WasteNET / app.py
# Author: harriskr14 — commit c4ed1c5 ("Initial commit - clean version")
import logging
import os
from io import BytesIO

import gradio as gr
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from google import genai
from google.genai import types
from PIL import Image
from torchvision import models

from hf_class import MyResNet18
from transforms import transform, target_transform
# Populate os.environ from a .env file (if one is found) so that the
# GOOGLE_API_KEY looked up via os.getenv when the Gemini client is created
# is available without exporting it in the shell.
load_dotenv()
# Maps each classifier output index (the position of a logit in the model's
# output row) to its waste-category name. Keys are the consecutive ints 0..9.
index_to_label = dict(enumerate((
    'battery',
    'biological',
    'cardboard',
    'clothes',
    'glass',
    'metal',
    'paper',
    'plastic',
    'shoes',
    'trash',
)))
# Prefer a CUDA GPU when one is present; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the WasteNET weights (presumably from the Hugging Face Hub repo
# 'harriskr14/WasteNET'), move the network onto `device`, and switch it to
# evaluation mode for inference. nn.Module.to() and .eval() both return the
# module itself, so the calls can be chained.
model = MyResNet18.from_pretrained('harriskr14/WasteNET').to(device).eval()
def _classify(img: Image.Image) -> tuple[str, float]:
    """Run the WasteNET classifier on *img*.

    Returns:
        The predicted category name (from ``index_to_label``) and its softmax
        probability rounded to 4 decimal places.
    """
    # transform['test'] is assumed to yield a single CHW tensor (TODO confirm);
    # unsqueeze(0) adds the batch dimension so the model sees a batch of one.
    img_tensor = transform['test'](img).unsqueeze(0).to(device)
    # Pure inference: disable autograd so no graph is built on every request.
    with torch.no_grad():
        logits = model(img_tensor)
    probabilities = F.softmax(logits, dim=1)
    _, preds = torch.max(logits, dim=1)
    pred_index = preds.item()
    confidence = round(probabilities[0][pred_index].item(), 4)
    return index_to_label[pred_index], confidence


def _suggest_management(img: Image.Image, predicted_label: str) -> str:
    """Ask Gemini for five concise ways to manage the classified waste.

    For the 'trash' class the image itself is sent (encoded as PNG), presumably
    because that label does not name a specific material. For every other
    class, only the label text is sent.
    """
    client = genai.Client(api_key=os.getenv('GOOGLE_API_KEY'))
    if predicted_label != 'trash':
        contents = f"Mention 5 ways of how to manage {predicted_label} waste concisely!"
    else:
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        contents = [
            "Mention 5 ways of how to manage this waste concisely!",
            types.Part.from_bytes(data=buffer.getvalue(), mime_type="image/png"),
        ]
    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=contents,
    )
    # response.text may be None when no text part comes back; gr.Markdown
    # would then render nothing, so substitute an explicit message.
    return response.text or "No suggestions were returned."


def predict_image(img: Image.Image) -> tuple[str, float, str]:
    """Classify an uploaded waste image and fetch management suggestions.

    Args:
        img: The uploaded PIL image. Gradio passes ``None`` when nothing was
            uploaded.

    Returns:
        ``(predicted_label, confidence_score, suggestions_markdown)``, matching
        the three output components of the Gradio interface.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    predicted_label, confidence_score = _classify(img)

    try:
        suggestions = _suggest_management(img, predicted_label)
    except Exception:
        # UI boundary: a Gemini failure (e.g. a network error or a missing
        # API key) should not hide the classification result, so log it and
        # degrade gracefully instead of failing the whole request.
        logging.getLogger(__name__).exception("Gemini suggestion request failed")
        suggestions = "Waste management suggestions are currently unavailable."

    return predicted_label, confidence_score, suggestions
# Gradio UI: one PIL image in; three outputs that line up with the tuple
# returned by predict_image — predicted label, confidence score, and
# Gemini-generated management suggestions rendered as Markdown.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Predicted Label"),
        gr.Number(label="Confidence Score"),
        gr.Markdown(label="Waste Management Suggestions", show_label=True, container=True),
    ],
    title="Waste Classification and Management Suggestion",
    description="Upload an image of garbage to classify its type and get suggestions on how to manage it.",
)

# Only start the server when run as a script (`python app.py`). Importing this
# module (tests, tooling) then no longer blocks on a running web server.
if __name__ == "__main__":
    # share=True asks Gradio for a temporary public link when run locally.
    demo.launch(share=True)