# app.py
"""Gradio demo that classifies an uploaded image with several image models.

Every model in ``MODEL_LIST`` is loaded once at import time. Models that fail
to load are skipped (the failure is printed) and the app keeps running with
the models that did load.
"""
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch

MODEL_LIST = [
    "prithivMLmods/Trash-Net",
    "yangy50/garbage-classification",
]

# The prediction of this model is reported as the final label.
FINAL_MODEL_NAME = "yangy50/garbage-classification"

# Parallel lists: index i of each list describes the same successfully loaded
# model. ``loaded_model_names`` exists because ``MODEL_LIST`` also contains
# models that may have failed to load, so it cannot be zipped with the others.
loaded_model_names = []
models = []
processors = []
devices = []

print("Loading models...")
for model_name in MODEL_LIST:
    try:
        processor = AutoImageProcessor.from_pretrained(model_name)
        model = AutoModelForImageClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        model.eval()
        device = next(model.parameters()).device
    except Exception as e:
        # Best effort: keep serving with whichever models are available.
        print(f"Failed to load {model_name}: {e}")
        continue

    # Append only after every step succeeded so the lists stay aligned.
    loaded_model_names.append(model_name)
    processors.append(processor)
    models.append(model)
    devices.append(device)
    print(f"Loaded: {model_name}")


def _to_model(tensor, device, dtype):
    """Move *tensor* to *device*, casting floating-point data to *dtype*."""
    if torch.is_floating_point(tensor):
        return tensor.to(device=device, dtype=dtype)
    return tensor.to(device)


def classify_image(image: Image.Image):
    """Classify *image* with every loaded model and return a text report.

    Args:
        image: The uploaded picture, or ``None`` when nothing was uploaded.

    Returns:
        One ``"<model name>: <label>"`` line per loaded model (or
        ``"<model name>: error:<message>"`` if inference failed), followed by
        the final label taken from ``yangy50/garbage-classification``
        (``"Unknown"`` when that model is not loaded). If *image* is ``None``
        a short prompt to upload an image is returned instead.
    """
    if image is None:
        return "No image provided. Please upload an image."

    results = {}
    for model_name, processor, model, device in zip(
        loaded_model_names, processors, models, devices
    ):
        try:
            inputs = processor(images=image, return_tensors="pt")
            # The processor emits float32 tensors; a model loaded in float16
            # rejects them, so match the model's dtype as well as its device.
            inputs = {
                k: _to_model(v, device, model.dtype) for k, v in inputs.items()
            }
            with torch.no_grad():
                outputs = model(**inputs)
            pred = outputs.logits.argmax(-1).item()
            label = model.config.id2label[pred]
            results[model_name] = label
        except Exception as e:
            # Report the failure for this model but keep going with the rest.
            results[model_name] = f"error:{e}"

    # Show the result of every model.
    results_text = "\n".join(f"{name}: {label}" for name, label in results.items())

    # Use yangy50/garbage-classification as the final result.
    final_label = results.get(FINAL_MODEL_NAME, "Unknown")
    results_text += f"\n\nFinal Label (base yangy50): {final_label}"

    return results_text


iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[gr.Textbox(label="Model Predictions")],
    title="Trash Classification",
    description=(
        "Upload an image, and the following models will classify it:\n"
        "1. prithivMLmods/Trash-Net\n"
        "2. yangy50/garbage-classification\n"
        "The final label is based on yangy50/garbage-classification."
    ),
)

if __name__ == "__main__":
    iface.launch()