# Author: Afeefa123
# Last commit: "Update app.py"
# Revision: f185c27 (verified)
import logging

import gradio as gr
import joblib
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.ensemble import RandomForestClassifier
# -------------------------------
# Load Pretrained Models
# -------------------------------
# `weights="DEFAULT"` asks torchvision for its default pretrained weights for
# each architecture; the legacy `pretrained=True` flag is deprecated since
# torchvision 0.13.
model_A = torch.hub.load('pytorch/vision', 'mobilenet_v2', weights="DEFAULT")
model_B = torch.hub.load('pytorch/vision', 'resnet18', weights="DEFAULT")
model_C = torch.hub.load('pytorch/vision', 'efficientnet_b0', weights="DEFAULT")
# Inference mode for all three: disables dropout and makes BatchNorm use its
# stored running statistics instead of per-batch statistics.
model_A.eval()
model_B.eval()
model_C.eval()
# -------------------------------
# Image Transform
# -------------------------------
# Per-channel ImageNet statistics. The torchvision pretrained backbones were
# trained on inputs normalised with these values; feeding un-normalised
# [0, 1] tensors skews their softmax outputs and therefore every
# confidence-based decision the cascade makes.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # PIL image -> float CxHxW tensor scaled to [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# -------------------------------
# Load / Create RF Model
# -------------------------------
logger = logging.getLogger(__name__)


def _load_or_train_rf(path="rf_model.pkl"):
    """Return the skip-decision Random Forest, loading it from *path* if possible.

    If *path* does not exist or cannot be loaded, a placeholder forest is fit
    on random (confidence, entropy, margin)-shaped data so the app still runs,
    and an attempt is made to cache it at *path*. Replace the placeholder with
    a forest trained on real cascade features.

    Args:
        path: location of the joblib-serialised RandomForestClassifier.

    Returns:
        A fitted ``RandomForestClassifier`` taking 3 features per sample.
    """
    try:
        return joblib.load(path)
    except FileNotFoundError:
        logger.info("No Random Forest at %s; training a placeholder.", path)
    except Exception:
        # Corrupt or incompatible pickle: keep the original best-effort
        # fallback, but record why instead of silently swallowing it.
        # (Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit through.)
        logger.warning("Could not load %s; training a placeholder.", path,
                       exc_info=True)

    # Dummy training (you can replace later). Seeded so the placeholder is
    # reproducible across restarts.
    rng = np.random.default_rng(0)
    X = rng.random((200, 3))
    y = rng.integers(0, 2, 200)
    model = RandomForestClassifier(
        n_estimators=50,
        max_depth=4,
        min_samples_leaf=10,
        class_weight="balanced",
        random_state=0,
    )
    model.fit(X, y)

    # Caching is an optimisation only; a read-only filesystem must not stop
    # the app from starting.
    try:
        joblib.dump(model, path)
    except OSError:
        logger.warning("Could not cache placeholder model to %s.", path,
                       exc_info=True)
    return model


rf_model = _load_or_train_rf()
# -------------------------------
# Utility Functions
# -------------------------------
def get_features(output):
    """Summarise classifier logits as (confidence, entropy, margin).

    Only the first row of the batch is used.

    Args:
        output: logits tensor with classes along dim 1 — presumably shape
            (batch, num_classes) as produced by the torchvision backbones.

    Returns:
        confidence: the highest softmax probability.
        entropy: Shannon entropy of the softmax distribution, in nats.
        margin: highest minus second-highest probability (equal to
            ``confidence`` when there is only a single class).
    """
    # .cpu() makes this safe if the models are ever moved to a GPU.
    probs = torch.softmax(output, dim=1).detach().cpu().numpy()[0]
    confidence = np.max(probs)
    # The 1e-10 guards against log(0) for zero-probability classes.
    entropy = -np.sum(probs * np.log(probs + 1e-10))
    # Second-largest probability via an O(n) partial sort; a single-class
    # output has no runner-up, so treat it as 0.
    runner_up = np.partition(probs, -2)[-2] if probs.size > 1 else 0.0
    margin = confidence - runner_up
    return confidence, entropy, margin
def predict_with_model(model, img):
    """Return *model*'s raw output for *img*, computed with autograd disabled.

    Gradient tracking is switched off because the cascade only performs
    inference, which saves memory and time.
    """
    with torch.no_grad():
        return model(img)
# -------------------------------
# Main Pipeline
# -------------------------------
# A model's prediction is accepted once its top softmax probability exceeds this.
CONFIDENCE_THRESHOLD = 0.7


def idk_pipeline(image):
    """Classify *image* with a confidence-gated cascade of three models.

    Model A runs first; if its top softmax probability exceeds
    ``CONFIDENCE_THRESHOLD`` its answer is returned. Otherwise ``rf_model``
    looks at A's (confidence, entropy, margin) features and predicts 0 (skip
    model B, go straight to model C) or anything else (try model B). Model B's
    answer is kept only if its own confidence exceeds the threshold; otherwise
    the cascade falls back to model C.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A human-readable string naming the model used and the predicted
        class index.
    """
    if image is None:
        return "Please upload an image."
    # The backbones expect 3 channels; uploads may arrive as RGBA, L or P.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img = transform(image).unsqueeze(0)

    # Step 1: Model A
    out_A = predict_with_model(model_A, img)
    conf, entropy, margin = get_features(out_A)
    # If confident → return
    if conf > CONFIDENCE_THRESHOLD:
        pred = torch.argmax(out_A).item()
        return f"Model A Prediction ✅\nClass: {pred}\nConfidence: {conf:.3f}"

    # Step 2: RF decides skip or not
    features = np.array([[conf, entropy, margin]])
    decision = rf_model.predict(features)[0]

    # 0 = skip B → go to C
    if decision == 0:
        out_C = predict_with_model(model_C, img)
        pred = torch.argmax(out_C).item()
        return f"Skipped Model B 🚀\nUsed Model C\nClass: {pred}"

    # 1 = use B; measure its confidence the same way as model A's.
    out_B = predict_with_model(model_B, img)
    conf_B, _, _ = get_features(out_B)
    if conf_B > CONFIDENCE_THRESHOLD:
        pred = torch.argmax(out_B).item()
        return f"Model B Prediction ✅\nClass: {pred}\nConfidence: {conf_B:.3f}"

    out_C = predict_with_model(model_C, img)
    pred = torch.argmax(out_C).item()
    return f"Fallback to Model C 🔄\nClass: {pred}"
# -------------------------------
# Gradio UI
# -------------------------------
interface = gr.Interface(
    fn=idk_pipeline,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="IDK Cascade with Random Forest Skipping",
    description="Implements dynamic skipping using Random Forest as described in research paper.",
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module (tests, tooling, reload mode) does not block on launch().
if __name__ == "__main__":
    interface.launch()