# Author: sarraarab
# Commit a1d5103 ("Update app.py", verified)
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import requests
from io import BytesIO
# ✅ Define model
class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier for RGB images.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2). The padded
    3x3 convolutions keep the spatial size, and each max-pool halves it
    (rounding down). After three stages every side is therefore
    ``image_size // 8``. A two-layer fully-connected head then produces one
    logit per class.

    Args:
        num_classes: Number of output logits. Defaults to 3, the original
            hard-coded head size.
        image_size: Side length of the square input images the head is sized
            for. Defaults to 256, which reproduces the original
            ``64 * 32 * 32`` flattened feature count.

    Raises:
        ValueError: If ``image_size`` is smaller than 8. The pooled feature
            map would then be empty.

    Submodule names and layer order (``conv.*`` / ``fc.*``) are unchanged, so
    state_dict checkpoints saved with the default configuration still load.
    """

    def __init__(self, num_classes=3, image_size=256):
        super().__init__()
        if image_size < 8:
            raise ValueError(
                f"image_size must be at least 8 (got {image_size}); three 2x "
                "max-pools would otherwise leave no spatial features"
            )
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # Three 2x max-pools shrink each spatial side to image_size // 8.
        feature_side = image_size // 8
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (N, num_classes).

        ``x`` is expected to be an (N, 3, image_size, image_size) float tensor.
        Other spatial sizes will not match the Linear layer's input width.
        """
        x = self.conv(x)
        return self.fc(x)
# ✅ Load model
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN().to(device)
# NOTE(review): torch.load unpickles the file. That is acceptable only because
# "sid_cnn_model.pt" ships with the app; never point this at user-supplied files.
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
model.load_state_dict(torch.load("sid_cnn_model.pt", map_location=device))
# Switch to inference mode before serving predictions.
model.eval()
# ✅ Labels & Transform
# Maps the argmax index over the model's 3 logits to the label shown to users.
LABELS = {0: "🟩 Real", 1: "🤖 AI-generated", 2: "🧩 Tampered"}
# Resize to 256x256 because SimpleCNN's head expects 64 * 32 * 32 features,
# i.e. a 256x256 input after three 2x max-pools.
# NOTE(review): there is no Normalize step; confirm this matches the
# preprocessing used when the checkpoint was trained.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
# ✅ Prediction function
def predict_image(image_url=None, uploaded_image=None):
    """Classify an image as real, AI-generated, or tampered.

    The image is downloaded from ``image_url`` when that is a non-blank
    string. Otherwise ``uploaded_image`` is used. It is expected to be a PIL
    image, as produced by ``gr.Image(type="pil")``.

    Args:
        image_url: Optional URL of the image to download. Gradio passes ""
            for an empty textbox. Whitespace-only input is treated as empty.
        uploaded_image: Optional PIL image uploaded through the UI.

    Returns:
        A user-facing string. It is one of:
        - the predicted label;
        - a prompt, when neither source is provided;
        - an error message.
        Failures are reported as text rather than raised, so the Gradio UI
        can display them.
    """
    try:
        url = image_url.strip() if image_url else ""
        if url:
            # NOTE(review): the server fetches an arbitrary user-supplied URL,
            # which is an SSRF risk on a public deployment. Consider an allow-list.
            response = requests.get(url, timeout=10)
            # Surface HTTP errors (404, 403, ...) directly instead of failing
            # later with a confusing "cannot identify image file" error.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif uploaded_image is not None:
            image = uploaded_image.convert("RGB")
        else:
            return "⚠️ Please provide an image URL or upload an image."
        # Add a batch dimension: the model consumes (N, C, H, W) batches.
        img_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(img_tensor)
        pred = outputs.argmax(1).item()
        return f"🧠 Prediction: {LABELS[pred]}"
    except Exception as e:  # UI boundary: report any failure as text
        return f"❌ Error: {e}"
# ✅ Gradio app with API route
with gr.Blocks() as app:
    gr.Markdown("## 🧠 AI Image Detector")
    url_input = gr.Textbox(label="Image URL")
    upload_input = gr.Image(label="Upload Image", type="pil")
    output = gr.Textbox(label="Result")
    detect_btn = gr.Button("Detect")
    # ✅ Expose public API route
    # api_name publishes this event as the app's named "/predict" endpoint.
    # A standalone gr.Interface that is never rendered into `app` or launched
    # would not expose any route.
    detect_btn.click(
        predict_image,
        inputs=[url_input, upload_input],
        outputs=output,
        api_name="predict",
    )
# ✅ Launch
# Listen on all interfaces on port 7860; share=True also requests a public
# Gradio share link.
app.launch(server_name="0.0.0.0", server_port=7860, share=True)