mustafa2ak's picture
Update app.py
2bd5ddb verified
import gradio as gr
import torch
import numpy as np
import random
import os
from PIL import Image
from model import PretrainedUNet
from dataset import LEVIRCDDataset
from inference import load_model, run_inference, create_overlay
from utils import setup_dataset
# Configuration
MODEL_PATH = "levid-cd-15.09_weights.pth"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize
print(f"Using device: {DEVICE}")
# Setup dataset
dataset_path = setup_dataset()
if dataset_path is None:
raise Exception("Failed to setup dataset")
# Load dataset
dataset = LEVIRCDDataset(dataset_path)
if len(dataset) == 0:
raise Exception("No valid images found in dataset")
# Load model
try:
model = load_model(MODEL_PATH, DEVICE)
print("Model loaded successfully!")
except Exception as e:
print(f"Error loading model: {e}")
model = None
def get_random_prediction():
"""Get random image and run prediction"""
if model is None:
return None, None, None, None
if len(dataset) == 0:
return None, None, None, None
# Get random sample
idx = random.randint(0, len(dataset) - 1)
sample = dataset.get_sample(idx)
# Run inference
pred_mask = run_inference(
model,
sample['img_a_tensor'],
sample['img_b_tensor'],
DEVICE
)
# Create visualizations
img_before = sample['img_a_orig']
img_after = sample['img_b_orig']
# Ground truth overlay (red) - FIX THIS PART
img_gt = create_overlay(
sample['img_b_orig'],
sample['mask_orig'],
[255, 100, 100],
threshold=0.5 # Changed from 128 to 0.5
)
# Prediction overlay (blue)
img_pred = create_overlay(
sample['img_b_orig'],
pred_mask,
[100, 100, 255],
threshold=0.5
)
return img_before, img_after, img_gt, img_pred
# Create Gradio interface
with gr.Blocks(title="Yapılaşma Değişim Tespiti", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🏗️ Yapay Zeka Yardımıyla Yapılaşmadaki Değişimlerin İzlenmesi
### 📖 Uygulama Hakkında
Bu yapay zekâ modeli, Amerika Birleşik Devletleri'nin **Texas eyaletinde, Austin şehrine ait uydu görüntüleri** kullanılarak eğitilmiştir.
Gösterilen tüm tahminler bu bölgeye aittir.
Sistem iki farklı sonucu karşılaştırmalı olarak sunar:
- **İnsan eliyle işaretlenmiş gerçek veriler** (Kırmızı renk ile gösterilir)
- **Yapay zekâ tarafından otomatik üretilen tahminler** (Mavi renk ile gösterilir)
""")
# Random button
with gr.Row():
btn = gr.Button(
"🎲 Rastgele Görsel Göster",
variant="primary",
scale=1
)
# Image grid - 2x2 layout
with gr.Row():
with gr.Column():
img_before = gr.Image(
label="📅 Önceki Görüntü",
type="pil",
height=300
)
with gr.Column():
img_after = gr.Image(
label="📅 Sonraki Görüntü",
type="pil",
height=300
)
with gr.Row():
with gr.Column():
img_gt = gr.Image(
label="👤 İnsan Eliyle İşaretlenmiş",
type="pil",
height=300
)
with gr.Column():
img_pred = gr.Image(
label="🤖 Yapay Zekâ Tahmini",
type="pil",
height=300
)
# Info footer
gr.Markdown("""
---
**Kaynak kodlara "Files" dosyası üzerinden ulaşabilirsiniz..
""")
# Connect button
btn.click(
fn=get_random_prediction,
inputs=None,
outputs=[img_before, img_after, img_gt, img_pred]
)
# Load initial image on startup
demo.load(
fn=get_random_prediction,
inputs=None,
outputs=[img_before, img_after, img_gt, img_pred]
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)