"""Gradio app that classifies waste images and asks Gemini for disposal tips.

A ResNet-18 loaded from ``harriskr14/WasteNET`` predicts one of ten waste
categories; Gemini is then asked for five concise ways to manage that waste.
"""

import functools
import os
from io import BytesIO

import gradio as gr
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from google import genai
from google.genai import types
from PIL import Image
from torchvision import models

from hf_class import MyResNet18
from transforms import transform, target_transform

# Populate os.environ (e.g. GOOGLE_API_KEY) from a .env file, if present.
load_dotenv()

# Maps the classifier's output index to its waste category name.
index_to_label = {
    0: 'battery',
    1: 'biological',
    2: 'cardboard',
    3: 'clothes',
    4: 'glass',
    5: 'metal',
    6: 'paper',
    7: 'plastic',
    8: 'shoes',
    9: 'trash',
}

# Gemini model used for the waste-management suggestions.
GEMINI_MODEL = "gemini-2.5-flash"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyResNet18.from_pretrained('harriskr14/WasteNET')
model = model.to(device)
model.eval()


@functools.lru_cache(maxsize=None)
def _get_client() -> genai.Client:
    """Return the Gemini client, created on first use and reused afterwards.

    Built lazily rather than per request, so each prediction does not
    construct a fresh client. The API key is read from GOOGLE_API_KEY.
    """
    return genai.Client(api_key=os.getenv('GOOGLE_API_KEY'))


@torch.no_grad()
def _classify(img: Image.Image) -> tuple[str, float]:
    """Run the classifier on *img*; return (label, softmax confidence).

    Gradients are disabled because this is inference only; without this,
    every request would build an unused autograd graph.
    The confidence is rounded to 4 decimal places.
    """
    img_tensor = transform['test'](img).unsqueeze(0).to(device)
    output = model(img_tensor)
    probabilities = F.softmax(output, dim=1)
    index = int(torch.argmax(output, dim=1).item())
    confidence = round(probabilities[0, index].item(), 4)
    return index_to_label[index], confidence


def _suggest(img: Image.Image, predicted_label: str) -> str:
    """Ask Gemini for five concise ways to manage the predicted waste type.

    For the catch-all ``'trash'`` class the label itself says little, so the
    image is sent to Gemini (as PNG bytes) instead of the label.
    """
    client = _get_client()
    if predicted_label != 'trash':
        contents = f"Mention 5 ways of how to manage {predicted_label} waste concisely!"
    else:
        with BytesIO() as buffer:
            img.save(buffer, format="PNG")
            image_bytes = buffer.getvalue()
        contents = [
            "Mention 5 ways of how to manage this waste concisely!",
            types.Part.from_bytes(data=image_bytes, mime_type="image/png"),
        ]
    response = client.models.generate_content(model=GEMINI_MODEL, contents=contents)
    # NOTE(review): response.text may be None when Gemini returns no text part
    # (e.g. a blocked response); fall back to a readable message.
    return response.text or "No suggestions could be generated for this image."


def predict_image(img: Image.Image) -> tuple[str, float, str]:
    """Classify an uploaded waste image and fetch management suggestions.

    Args:
        img: PIL image from the Gradio ``Image`` input; ``None`` when the user
            submits without uploading anything.

    Returns:
        ``(predicted_label, confidence_score, suggestions_markdown)``, matching
        the three outputs of the Gradio interface.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    predicted_label, confidence_score = _classify(img)
    return predicted_label, confidence_score, _suggest(img, predicted_label)


demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Predicted Label"),
        gr.Number(label="Confidence Score"),
        gr.Markdown(label="Waste Management Suggestions", show_label=True, container=True),
    ],
    title="Waste Classification and Management Suggestion",
    description="Upload an image of garbage to classify its type and get suggestions on how to manage it.",
)
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(share=True)