Spaces:
Sleeping
Sleeping
import time
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
# -----------------------------
# Configuration and Setup
# -----------------------------
# Select the PyTorch device: the GPU when CUDA is available, otherwise the CPU.
# The model and every preprocessed batch are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Location passed to from_pretrained() for both the processor and the model.
model_path = "final_model"

# Load the image processor and model once, at import time, so every
# prediction request reuses the same objects.
try:
    print("Loading image processor...")
    processor = ViTImageProcessor.from_pretrained(model_path)
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(model_path)
    model = model.to(device)
    # eval() disables training-only behaviour (e.g. dropout) so repeated
    # predictions on the same input produce the same scores.
    model.eval()
except Exception as e:
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Error loading model: {e}") from e
# Class-index -> display-name mapping used when formatting predictions.
# Fall back to generic "Class i" names when the model config has no usable
# id2label mapping. An explicit check is used instead of `assert`, because
# assertions are stripped when Python runs with -O.
labels = getattr(model.config, "id2label", None)
if not isinstance(labels, dict) or len(labels) == 0:
    print("⚠️ Labels not found in model config: Invalid or empty id2label mapping")
    labels = {i: f"Class {i}" for i in range(model.config.num_labels)}
# -----------------------------
# Standalone Test Mode (Optional)
# -----------------------------
def test_inference():
    """Smoke-test the loaded model outside Gradio on a synthetic image.

    Builds a solid red 224x224 RGB image, runs it through the module-level
    ``processor`` and ``model`` on ``device``, and prints the outcome.
    Failures are reported, not raised, so the caller can still continue
    (e.g. launch the UI).

    Returns:
        bool: True if the forward pass completed, False if any step raised.
    """
    dummy_img = Image.new("RGB", (224, 224), color="red")
    print("Running standalone inference test...")
    try:
        inputs = processor(images=dummy_img, return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = model(**inputs)
    except Exception as e:
        # Deliberately best-effort: report the failure instead of crashing.
        print(f"❌ Inference test failed: {e}")
        return False
    print(f"✅ Model inference test successful (logits shape: {tuple(outputs.logits.shape)})")
    return True
# -----------------------------
# Prediction Function
# -----------------------------
def predict(image, *, top_k=5):
    """Classify an image and return its top predictions as display text.

    Args:
        image: The uploaded image (a PIL image when called from the
            ``gr.Image(type="pil")`` component), or None when nothing was
            uploaded.
        top_k: Maximum number of classes to list. It is clamped to the number
            of classes the model outputs, so small label sets do not fail.

    Returns:
        str: One ``"rank. label — probability%"`` line per class. On failure,
        an error message is returned instead. Errors are returned as text
        rather than raised so they appear in the UI textbox.
    """
    if image is None:
        return "No image uploaded."

    print("\n[INFO] Starting prediction pipeline...")

    # Step 1: Preprocessing
    print("[STEP 1] Preprocessing image...")
    try:
        start = time.perf_counter()
        # ViT checkpoints are normally trained on 3-channel RGB input; convert
        # other PIL modes (RGBA, L, P, ...) so preprocessing does not hit a
        # channel mismatch.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)
        print(f"[DEBUG] Input shape: {inputs['pixel_values'].shape}")
        print(f"[DEBUG] Time taken: {time.perf_counter() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in preprocessing: {e}"

    # Step 2: Inference
    print("[STEP 2] Running inference...")
    try:
        start = time.perf_counter()
        with torch.inference_mode():
            outputs = model(**inputs)
        print(f"[DEBUG] Inference completed in {time.perf_counter() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in model inference: {e}"

    # Step 3: Post-processing
    print("[STEP 3] Processing output...")
    try:
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # torch.topk raises when k exceeds the number of classes, so clamp k.
        k = min(top_k, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs[0], k)
        lines = [
            f"{rank}. {labels.get(idx, f'Unknown class {idx}')} — {prob * 100:.2f}%"
            for rank, (idx, prob) in enumerate(
                zip(top_indices.tolist(), top_probs.tolist()), start=1
            )
        ]
        result = "\n".join(lines)
    except Exception as e:
        return f"❌ Error post-processing: {e}"

    print("[INFO] Prediction complete ✅\n")
    return result
# -----------------------------
# Gradio Interface
# -----------------------------
# Offer the bundled example only when the file is actually present. The path is
# relative to the working directory, as is the path Gradio receives.
EXAMPLE_IMAGE = Path("examples") / "test_image.jpg"

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Top 5 Predictions"),
    title="Fine-Tuned ViT Image Classifier",
    description="Upload an image to get the top 5 predicted classes with confidence scores.",
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the pinned gradio version.
    allow_flagging="never",
    examples=[[str(EXAMPLE_IMAGE)]] if EXAMPLE_IMAGE.is_file() else None,
)
if __name__ == "__main__":
    print("\n🚀 Launching Gradio interface...\n")
    # Optional smoke test: report a broken model before the UI comes up.
    # Its outcome is only printed; the interface launches either way.
    test_inference()
    # share=True asks Gradio for a temporary public link when run locally.
    interface.launch(share=True)