Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import requests | |
| from io import BytesIO | |
# Model definition: must match the architecture the checkpoint was saved from.
class SimpleCNN(nn.Module):
    """Small three-stage CNN that maps a square RGB image to class logits.

    Architecture: three ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``
    stages (3 -> 16 -> 32 -> 64 channels), then
    ``Flatten -> Linear(-> 128) -> ReLU -> Linear(-> num_classes)``.

    Args:
        image_size: Side length in pixels of the square input the model is
            built for. Each pooling stage halves the spatial size with
            flooring, so the flattened feature size is
            ``64 * (image_size // 8) ** 2``. The default of 256 reproduces
            the original hard-coded ``64 * 32 * 32`` head.
        num_classes: Number of output logits. Defaults to 3.

    Raises:
        ValueError: If ``image_size`` is smaller than 8, which would collapse
            the feature map to zero pixels.

    With the default arguments the layer layout (and therefore the
    ``state_dict`` keys and tensor shapes) is identical to the original,
    so existing checkpoints still load.
    """

    def __init__(self, image_size=256, num_classes=3):
        super().__init__()
        if image_size < 8:
            raise ValueError(f"image_size must be at least 8, got {image_size}")
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # Three 2x2 pools shrink each side by a factor of 8; repeated floor
        # division by 2 equals a single floor division by 8.
        feature_side = image_size // 8
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        ``x`` is expected to be ``(batch, 3, image_size, image_size)``; other
        spatial sizes will not match the first ``Linear`` layer.
        """
        return self.fc(self.conv(x))
# Load the trained weights and switch the model to inference mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN().to(device)
# A state_dict holds only tensors in plain containers. weights_only=True
# limits unpickling to such objects, so a tampered checkpoint file cannot
# run arbitrary code at load time. The flag needs torch >= 1.13 and is the
# default from torch 2.6 onward.
model.load_state_dict(
    torch.load("sid_cnn_model.pt", map_location=device, weights_only=True)
)
# eval() puts layers into inference behaviour; this architecture has no
# dropout/batch-norm today, but it keeps that correct if they are added.
model.eval()
# Display labels keyed by the classifier's output index.
LABELS = {0: "🟩 Real", 1: "🤖 AI-generated", 2: "🧩 Tampered"}
# Preprocessing applied to every input image before inference.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),  # the fixed size the classifier head was built for
        transforms.ToTensor(),  # PIL image -> float CHW tensor
    ]
)
# Prediction function used by the UI button and its API route.
def predict_image(image_url=None, uploaded_image=None):
    """Classify an image and return a human-readable result string.

    The image comes from ``image_url`` when it is a non-blank string;
    otherwise from ``uploaded_image``.

    Args:
        image_url: Optional URL of an image to download. Surrounding
            whitespace is ignored; an empty or blank string counts as absent.
        uploaded_image: Optional PIL image (as produced by
            ``gr.Image(type="pil")``).

    Returns:
        One of three strings:
        - the predicted label;
        - a prompt when neither source is given;
        - an error message.
        Errors are returned as text rather than raised, so the UI always has
        something to display.
    """
    # Gradio sends "" for an empty textbox; treat whitespace-only the same way.
    url = image_url.strip() if isinstance(image_url, str) else image_url
    try:
        if url:
            # NOTE(review): this fetches a user-supplied URL server-side.
            # Consider restricting schemes/hosts (SSRF) and capping the
            # download size if the app is exposed publicly.
            response = requests.get(url, timeout=10)
            # Report HTTP failures (404, 500, ...) directly instead of letting
            # PIL fail on an HTML error page.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif uploaded_image is not None:
            image = uploaded_image.convert("RGB")
        else:
            return "⚠️ Please provide an image URL or upload an image."

        # Add a batch dimension of 1 and move the tensor to the model's device.
        img_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(img_tensor)
        pred = outputs.argmax(1).item()
        return f"🧠 Prediction: {LABELS[pred]}"
    except Exception as e:
        # UI boundary: any download, decode or inference failure is shown to
        # the user as text instead of crashing the request.
        return f"❌ Error: {e}"
# Gradio UI. The Detect button's handler is also published as the named
# API endpoint "predict".
with gr.Blocks() as app:
    gr.Markdown("## 🧠 AI Image Detector")
    url_input = gr.Textbox(label="Image URL")
    upload_input = gr.Image(label="Upload Image", type="pil")
    output = gr.Textbox(label="Result")
    detect_btn = gr.Button("Detect")
    # Passing api_name to the event listener is what registers the API route.
    # Assigning .api_name on a separately built gr.Interface has no effect,
    # because that Interface is never rendered into this app or launched.
    detect_btn.click(
        predict_image,
        inputs=[url_input, upload_input],
        outputs=output,
        api_name="predict",
    )

# Serve on all network interfaces, port 7860.
# NOTE(review): share=True asks Gradio for a public share link when run
# locally; hosted platforms such as Hugging Face Spaces are expected to
# ignore it. Confirm it is still wanted.
app.launch(server_name="0.0.0.0", server_port=7860, share=True)