import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import requests
from io import BytesIO


# ✅ Define model
class SimpleCNN(nn.Module):
    """Small three-stage CNN that scores an RGB image against three classes.

    The classifier's input size (64 * 32 * 32) assumes 256x256 images: each of
    the three 2x2 max-pools halves the spatial size (256 -> 128 -> 64 -> 32).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32 * 32, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
        )

    def forward(self, x):
        """Return unnormalized class scores (no softmax) for a batch ``x``.

        Expects ``x`` of shape (N, 3, 256, 256) to match the classifier size.
        """
        x = self.conv(x)
        x = self.fc(x)
        return x


# ✅ Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN().to(device)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load("sid_cnn_model.pt", map_location=device))
model.eval()  # switch to inference mode

# ✅ Labels & Transform
LABELS = {0: "🟩 Real", 1: "🤖 AI-generated", 2: "🧩 Tampered"}
# Resize to the 256x256 input size SimpleCNN's classifier was built for.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])


# ✅ Prediction function
def predict_image(image_url=None, uploaded_image=None):
    """Classify an image supplied by URL or upload and return a display string.

    Args:
        image_url: Optional URL of an image. When non-blank it takes precedence
            over ``uploaded_image``.
        uploaded_image: Optional PIL image (as produced by
            ``gr.Image(type="pil")``).

    Returns:
        A human-readable string: the predicted label, a prompt asking for an
        image, or an error message. Exceptions are caught and reported in the
        returned string rather than raised.
    """
    try:
        # Treat a whitespace-only URL as "no URL" so an upload can still be used.
        url = (image_url or "").strip()
        if url:
            # NOTE(review): this app is publicly reachable and fetches arbitrary
            # URLs server-side; consider restricting hosts and response size.
            response = requests.get(url, timeout=10)
            # Report HTTP errors (404, 403, ...) directly instead of letting PIL
            # fail on an HTML error page with a confusing message.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif uploaded_image is not None:
            image = uploaded_image.convert("RGB")
        else:
            return "⚠️ Please provide an image URL or upload an image."

        img_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim
        with torch.no_grad():
            outputs = model(img_tensor)
        pred = outputs.argmax(1).item()
        label = LABELS[pred]
        return f"🧠 Prediction: {label}"
    except Exception as e:
        # Top-level UI boundary: show the failure to the user instead of crashing.
        return f"❌ Error: {str(e)}"


# ✅ Gradio app with API route
with gr.Blocks() as app:
    gr.Markdown("## 🧠 AI Image Detector")
    url_input = gr.Textbox(label="Image URL")
    upload_input = gr.Image(label="Upload Image", type="pil")
    output = gr.Textbox(label="Result")
    detect_btn = gr.Button("Detect")
    # ✅ Expose public API route: api_name on the event listener is what
    # registers the "/predict" endpoint for API / gradio_client callers.
    detect_btn.click(
        predict_image,
        inputs=[url_input, upload_input],
        outputs=output,
        api_name="predict",
    )

# ✅ Launch
app.launch(server_name="0.0.0.0", server_port=7860, share=True)