File size: 4,519 Bytes
0171016
7ef7e24
 
004c451
7ef7e24
 
032e591
7ef7e24
 
 
0171016
568d41a
7ef7e24
677f68c
e49e74e
 
 
7ef7e24
744eb1f
 
 
 
012de12
 
 
 
744eb1f
 
012de12
 
744eb1f
 
 
 
 
 
 
 
 
 
 
 
c998360
 
 
7ef7e24
 
 
 
 
0171016
7ef7e24
 
 
 
 
 
 
004c451
 
7ef7e24
004c451
7ef7e24
 
 
 
0171016
7ef7e24
004c451
 
 
 
 
7ef7e24
004c451
7ef7e24
 
 
 
004c451
7ef7e24
004c451
 
 
7ef7e24
 
 
 
 
004c451
7ef7e24
 
 
004c451
7ef7e24
 
0171016
7ef7e24
 
004c451
7ef7e24
 
 
 
 
 
 
 
 
0171016
 
7ef7e24
 
 
 
 
0171016
 
 
7ef7e24
 
 
 
 
004c451
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# app.py – Hugging Face Spaces app: urban image segmentation with a dynamic image list fetched from EC2 (CSV export is mentioned but not implemented in this file)

import requests
from requests.exceptions import RequestException
import gradio as gr
import base64
from PIL import Image, ImageDraw
import io
import numpy as np
from collections import deque
import pandas as pd
import tempfile


# Every endpoint is served by the same EC2 host (plain HTTP, port 8000).
_EC2_ROOT = "http://18.212.167.3:8000"
API_URL = f"{_EC2_ROOT}/predict"  # POST an image here, JSON reply carries the mask
IMAGE_BASE_URL = f"{_EC2_ROOT}/test_images/"  # prefix; a file name is appended to download it
IMAGE_LIST_URL = f"{_EC2_ROOT}/list-test-images"  # GET the list of available test-image names

def generate_dummy_legend():
    """Render a placeholder legend image for the segmentation classes.

    Each class occupies one 20-pixel row on a 320x160 white canvas: a filled
    colour swatch on the left and the class name in black next to it.

    NOTE(review): these colours are placeholders ("fictive" per the UI label);
    they are not verified to match the palette returned by the /predict API.

    Returns:
        PIL.Image.Image: the RGB legend image.
    """
    legend_entries = (
        ("Flat", (70, 70, 70)),
        ("Construction", (128, 64, 128)),
        ("Object", (107, 142, 35)),
        ("Nature", (0, 0, 0)),
        ("Sky", (70, 130, 180)),
        ("Human", (220, 20, 60)),
        ("Vehicle", (0, 0, 142)),
        ("Ignore", (102, 102, 156)),
    )
    row_height = 20

    canvas = Image.new("RGB", (320, 160), color=(255, 255, 255))
    pen = ImageDraw.Draw(canvas)

    for row, (name, rgb) in enumerate(legend_entries):
        top = row * row_height + 5
        # Swatch spans x 5..25 and is 15 px tall; the label starts at x=30.
        pen.rectangle([5, top, 25, top + 15], fill=rgb)
        pen.text((30, top), name, fill=(0, 0, 0))

    return canvas

# Legend image built once at import time; displayed in the UI's legend accordion.
legend_img = generate_dummy_legend()
# Bounded buffer keeping at most the 5 most recent entries (older ones drop off).
# NOTE(review): not read or written anywhere else in this file — presumably a
# placeholder for a prediction history; wire it up or remove it.
history = deque(maxlen=5)

def overlay_mask_on_image(image: Image.Image, mask: Image.Image, alpha: float = 0.5) -> Image.Image:
    """Alpha-blend a segmentation mask on top of an image.

    ``Image.blend`` needs both inputs in the same mode and size, so the image
    is resized to the mask's dimensions and both are converted to RGBA.

    Args:
        image: Source picture.
        mask: Colour mask; its size determines the output size.
        alpha: Weight given to the mask (0.0 = image only, 1.0 = mask only).

    Returns:
        The blended RGBA image.
    """
    target_size = mask.size
    base_rgba = image.convert("RGBA").resize(target_size)
    mask_rgba = mask.convert("RGBA")
    return Image.blend(base_rgba, mask_rgba, alpha=alpha)

def segment_image(image: Image.Image, source_name="image_upload.png"):
    """Send an image to the EC2 ``/predict`` API and build the four UI outputs.

    The image is PNG-encoded and POSTed as the multipart field ``file``. The
    JSON reply is expected to contain ``mask_base64`` (a base64-encoded image)
    and ``inference_time``.

    Args:
        image: PIL image to segment. ``None`` produces an error result.
        source_name: Label describing where the image came from. Not used by
            this function; accepted so callers may pass it.

    Returns:
        ``(image, mask, overlay, time_text)`` on success. On failure the three
        images are ``None`` and ``time_text`` holds an error message.
    """
    try:
        if image is None:
            raise ValueError("Aucune image fournie.")
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        buffered.seek(0)
        files = {"file": ("input.png", buffered, "image/png")}

        # Only the transport step maps to the "unreachable/slow" message.
        # Since requests 2.27, response.json() raises
        # requests.exceptions.JSONDecodeError, which also subclasses
        # RequestException. With a single outer RequestException handler, a
        # malformed reply was misreported as the server being down.
        try:
            response = requests.post(API_URL, files=files, timeout=10)
            response.raise_for_status()
        except RequestException as e:
            print(f"[ERREUR API EC2] {e}")
            return None, None, None, "Erreur : EC2 injoignable ou lente"

        data = response.json()
        mask_bytes = base64.b64decode(data["mask_base64"])
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
        overlay_img = overlay_mask_on_image(image, mask_image)

        return image, mask_image, overlay_img, f"{data['inference_time']} sec"

    except Exception as e:
        # UI boundary: any other failure (missing input, malformed reply,
        # undecodable mask) is shown in the textbox instead of crashing.
        print(f"[ERREUR GRADIO] {e}")
        return None, None, None, f"Erreur : {str(e)}"

def get_remote_image_names():
    """Fetch the test-image file names published by the EC2 server.

    Returns:
        The ``"files"`` entry of the JSON reply, or ``[]`` when that key is
        absent. On any failure (network, HTTP status, unexpected payload) the
        error is printed and a one-element list holding an error placeholder
        is returned instead.
    """
    try:
        reply = requests.get(IMAGE_LIST_URL, timeout=5)
        reply.raise_for_status()
        payload = reply.json()
        return payload.get("files", [])
    except Exception as err:
        print(f"[ERREUR LISTE FICHIERS] {err}")
        return ["Erreur lors du chargement des noms"]

def load_image_from_url(filename):
    """Download one test image from the EC2 server and decode it as RGB.

    Args:
        filename: Image name (as listed by the server), appended to
            ``IMAGE_BASE_URL`` after percent-encoding.

    Returns:
        The RGB PIL image, or ``None`` on any failure (no selection, network
        or HTTP error, undecodable content). Failures are printed.
    """
    from urllib.parse import quote

    if not filename:
        # Nothing selected in the dropdown (None or empty string).
        print(f"[ERREUR CHARGEMENT IMAGE {filename}] aucun fichier sélectionné")
        return None
    try:
        # Percent-encode the name so characters such as ' ', '#', '?' or '%'
        # stay part of the path instead of starting a query/fragment.
        # '/' is left unescaped in case the server lists sub-directory paths.
        url = IMAGE_BASE_URL + quote(filename)
        response = requests.get(url, timeout=5)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    except Exception as e:
        print(f"[ERREUR CHARGEMENT IMAGE {filename}] {e}")
        return None


# Gradio UI: either upload a local image or pick a remote EC2 image, then
# segment it. Components render in the order they are created below.
with gr.Blocks(title="Segmentation d'Images Urbaines") as demo:
    gr.Markdown("# 🧠 Segmentation d'Images Urbaines")
    gr.Markdown("Upload une image ou sélectionne une image distante depuis EC2")

    # Section 1: local upload and its trigger button.
    with gr.Row():
        input_image = gr.Image(type="pil", label="Image d'entrée")
        btn = gr.Button("Segmenter")
    # Shared result panel, filled by both buttons in this UI.
    with gr.Row():
        img_original = gr.Image(label="Image originale")
        img_mask = gr.Image(label="Mask prédit")
        img_overlay = gr.Image(label="Superposition (image + mask)")
        inf_time = gr.Textbox(label="Temps d'inférence")
    # segment_image returns a 4-tuple matching these four outputs, in order.
    btn.click(fn=lambda img: segment_image(img), inputs=input_image,
              outputs=[img_original, img_mask, img_overlay, inf_time])

    gr.Markdown("## 🌍 Ou sélectionne une image hébergée sur EC2 :")
    with gr.Row():
        # NOTE(review): choices are fetched once, while this module builds the
        # UI; nothing in this file refreshes the list afterwards.
        dropdown = gr.Dropdown(label="Fichier distant (EC2)", choices=get_remote_image_names())
        btn_url = gr.Button("Segmenter cette image distante")
    # Download the selected file, then reuse the same segmentation pipeline.
    # If the download fails, load_image_from_url returns None and
    # segment_image reports the missing-image error.
    btn_url.click(fn=lambda name: segment_image(load_image_from_url(name), source_name=name),
                  inputs=dropdown, outputs=[img_original, img_mask, img_overlay, inf_time])


    # Collapsible legend showing the placeholder class/colour image.
    with gr.Accordion("🎨 Légende des classes", open=False):
        gr.Image(value=legend_img, label="Légende (fictive ici)", interactive=False)

# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()