| | from PIL import Image, ImageDraw, ImageFont |
| | import os |
| | import base64 |
| | import json |
| | import requests |
| | from io import BytesIO |
| | import threading |
| | from datetime import datetime |
| | import paho.mqtt.client as mqtt |
| | import gradio as gr |
| | from api import predict_image |
| |
|
| | |
# Path where the most recently received/selected raw image is stored.
IMAGE_PATH = "received_image.jpg"
# Directory holding timestamped copies of every received or uploaded image.
IMAGE_HISTORY_DIR = "image_history"
# Maximum number of images kept in the on-disk history.
MAX_HISTORY_SIZE = 100
# Default MQTT connection settings used by auto_connect() and the settings UI.
# NOTE(review): broker address and credentials are hard-coded; consider
# loading them from environment variables instead of committing them.
MQTT_CONFIG = {
    "broker": "47.254.33.128",
    "port": 1883,
    "topic": "x1/bugs",
    "username": "my",
    "password": "my123456"
}
| |
|
| | |
# --- Mutable module-level state shared between MQTT callbacks and the UI ---
mqtt_client = None  # active paho MQTT client, or None before the first connect
# Most recent annotated image: file path, capture timestamp, detected-object count.
latest_image_info = {"path": None, "date": None, "objnum": None}
# History of (image_path, timestamp_string) tuples, oldest first.
image_history = []
# HTML snippet rendered in the UI banner to show connection state.
mqtt_status = "<span style='color: red;'>MQTT Disconnected</span>"
# Prompt and task token currently forwarded to the prediction API.
current_prompt = "all"
current_task = "<OD>"
| |
|
# Maps human-readable dropdown labels to model task tokens.
# NOTE(review): tokens look like Florence-2-style task prompts — confirm
# against api.predict_image, which consumes them.
task_name = {
    "detect all objects": "<OD>",
    "detect by vocabulary": "<OPEN_VOCABULARY_DETECTION>",
    "detect by phrase": "<CAPTION_TO_PHRASE_GROUNDING>"
}
| |
|
| | |
# Ensure the history directory exists before any image is saved into it.
os.makedirs(IMAGE_HISTORY_DIR, exist_ok=True)
| |
|
| | |
def on_connect(client, userdata, flags, rc):
    """MQTT connect callback: subscribe on success, update the status banner."""
    global mqtt_status
    if rc != 0:
        # Non-zero result code means the broker refused the connection.
        mqtt_status = "<span style='color: red;'>MQTT Disconnected</span>"
        return
    client.subscribe(MQTT_CONFIG["topic"])
    mqtt_status = "<span style='color: green;'>MQTT Connected</span>"
| |
|
def on_disconnect(client, userdata, rc):
    """MQTT disconnect callback: flip the status banner to 'Disconnected'."""
    global mqtt_status
    mqtt_status = "<span style='color: red;'>MQTT Disconnected</span>"
| |
|
def on_message(client, userdata, msg):
    """MQTT message callback: hand each payload to a worker thread.

    Decoding + inference are slow, so they must not block the network loop.
    """
    worker = threading.Thread(target=handle_message, args=(msg,))
    worker.start()
| |
|
def handle_message(msg):
    """Decode an incoming MQTT image payload, run detection, and update UI state.

    Runs on a worker thread (see on_message). Any failure is logged and
    swallowed so one bad payload cannot kill message processing.
    """
    try:
        print("Received message")
        data = json.loads(msg.payload)
        # Payload shape: {"values": {"image": "<header>,<base64>", "localtime": "..."}}
        # -- inferred from the split(",")[1] below; confirm against the publisher.
        image_data = data["values"]["image"].split(",")[1]
        localtime = data["values"]["localtime"]

        image = Image.open(BytesIO(base64.b64decode(image_data)))
        if image.mode == "RGBA":
            # JPEG cannot store an alpha channel.
            image = image.convert("RGB")
        image.save(IMAGE_PATH)

        # Timestamp becomes the history filename; spaces/colons are not filesystem-safe.
        image_history_path = os.path.join(IMAGE_HISTORY_DIR, f"{localtime.replace(' ', '_').replace(':', '-')}.jpg")
        image.save(image_history_path)

        prediction = predict_image_json(image, current_task, current_prompt)
        annotated_image_path = annotate_image(image, prediction, current_task)
        detected_objects = predicted_objects_num(prediction, current_task)

        # Publish results to the shared state polled by the UI.
        latest_image_info.update({
            "path": annotated_image_path,
            "date": localtime,
            "objnum": detected_objects
        })

        image_history.append((image_history_path, localtime))
        manage_history_size()
    except Exception as e:
        # Best-effort boundary: log and drop the message.
        print(f"Error processing message: {e}")
| |
|
def convert_to_od_format(data):
    """Normalize an open-vocabulary detection result to <OD>-style output.

    Open-vocabulary results carry their labels under 'bboxes_labels'; the
    annotation code expects the <OD> keys 'bboxes' and 'labels'. Missing
    keys map to empty lists.
    """
    return {
        'bboxes': data.get('bboxes', []),
        'labels': data.get('bboxes_labels', []),
    }
| |
|
def predict_image_json(image, task, prompt):
    """Run the detection API and normalize the result to <OD>-shaped output.

    Args:
        image: PIL image to analyze.
        task: task token, e.g. "<OD>" or "<OPEN_VOCABULARY_DETECTION>".
        prompt: text prompt; blanked for plain object detection.

    Returns:
        dict keyed by the task token, containing 'bboxes'/'labels' lists.
    """
    # Fix: dropped the unused `msgid` timestamp local the original computed.
    if task == "<OD>":
        # Plain object detection takes no prompt.
        prompt = ""
    prediction = predict_image(image, task, prompt)
    if task == "<OPEN_VOCABULARY_DETECTION>":
        # Rename 'bboxes_labels' -> 'labels' so annotate_image can consume it.
        prediction[task] = convert_to_od_format(prediction[task])
    return prediction
| |
|
def annotate_image(image, prediction, task):
    """Draw labeled bounding boxes onto *image* and save the annotated copy.

    Draws in place (mutates the caller's image) and returns the path of the
    saved annotated JPEG.
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size
    # Scale stroke width and font size with the image so labels stay legible.
    scale = max(width, height) / 1000
    font_size = int(30 * scale)
    line_width = int(3 * scale)
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", font_size)
    except IOError:
        font = ImageFont.load_default()

    boxes = prediction[task]["bboxes"]
    labels = prediction[task]["labels"]
    for (x1, y1, x2, y2), label in zip(boxes, labels):
        draw.rectangle([x1, y1, x2, y2], outline="yellow", width=line_width)
        # Black backdrop behind the label text for readability.
        backdrop = draw.textbbox((x1, y1), label, font=font)
        draw.rectangle([backdrop[0], backdrop[1], backdrop[2], backdrop[3]], fill="black")
        draw.text((x1, y1), label, fill="white", font=font)

    annotated_image_path = IMAGE_PATH.replace(".jpg", "_annotated.jpg")
    image.save(annotated_image_path)
    return annotated_image_path
| |
|
def predicted_objects_num(prediction, task):
    """Return how many bounding boxes the prediction contains for *task*."""
    return len(prediction[task]["bboxes"])
| |
|
def start_mqtt_client(broker, port, topic, username, password):
    """(Re)connect the global MQTT client and start its background network loop.

    Any previously running client is fully shut down first so only one loop
    thread and one broker connection exist at a time.
    """
    global mqtt_client
    if mqtt_client is not None:
        # Fix: also stop the old network-loop thread; disconnect() alone
        # leaves the thread started by loop_start() running (thread leak on
        # every reconnect via the settings panel).
        mqtt_client.loop_stop()
        mqtt_client.disconnect()
    mqtt_client = mqtt.Client()
    mqtt_client.username_pw_set(username, password)
    mqtt_client.on_connect = on_connect
    mqtt_client.on_disconnect = on_disconnect
    mqtt_client.on_message = on_message
    mqtt_client.connect(broker, port, 60)
    mqtt_client.loop_start()
| |
|
def display_image():
    """Return (annotated image path, object count) for the latest image."""
    print("Displaying latest image...")
    info = latest_image_info
    return info["path"], info["objnum"]
| |
|
def display_image_history():
    """Return a fresh (path, timestamp) list snapshot for the gallery widget."""
    return list(image_history)
| |
|
def show_prediction_on_history(evt: gr.SelectData):
    """Gallery click handler: re-run detection on the selected history image.

    Also promotes the selected image to the 'latest' slot so subsequent
    prompt commits and refreshes operate on it.
    Returns (annotated image path, detected object count) for the UI.
    """
    image_path = image_history[int(evt.index)][0]
    image = Image.open(image_path)
    # Overwrite the 'latest' raw image so commit_prompt() re-analyzes this one.
    image.save(IMAGE_PATH)
    prediction = predict_image_json(image, current_task, current_prompt)
    annotated_image_path = annotate_image(image, prediction, current_task)
    predicted_objects = predicted_objects_num(prediction, current_task)
    latest_image_info["path"] = annotated_image_path
    latest_image_info["objnum"] = predicted_objects
    return annotated_image_path, predicted_objects
| |
|
def update_mqtt_config(broker, port, topic, username, password):
    """Reconnect the MQTT client with the given settings; return a summary."""
    start_mqtt_client(broker, int(port), topic, username, password)
    summary = f"Connected to {broker}:{port}, subscribed to topic '{topic}'"
    return summary
| |
|
def auto_connect():
    """Connect to the broker at startup using the defaults in MQTT_CONFIG."""
    cfg = MQTT_CONFIG
    update_mqtt_config(
        cfg["broker"],
        cfg["port"],
        cfg["topic"],
        cfg["username"],
        cfg["password"],
    )
| |
|
def history_image_load():
    """Rebuild image_history from the files already on disk in IMAGE_HISTORY_DIR.

    History filenames encode the capture time as "YYYY-MM-DD_HH-MM-SS.jpg"
    (see handle_message/upload_image); this reconstructs the original
    "YYYY-MM-DD HH:MM:SS" display string.

    Fix: the previous code replaced every "-" with ":" — corrupting the date
    part ("2024-01-02" became "2024:01:02") — and left the ".jpg" extension
    inside the displayed timestamp.
    """
    global image_history
    image_history = []
    for filename in os.listdir(IMAGE_HISTORY_DIR):
        if not filename.endswith(".jpg"):
            continue
        stem = os.path.splitext(filename)[0]
        # Only the time portion had ':' replaced by '-'; the date keeps its dashes.
        date_part, _, time_part = stem.partition("_")
        localtime = f"{date_part} {time_part.replace('-', ':')}"
        image_history.append((os.path.join(IMAGE_HISTORY_DIR, filename), localtime))
    # ISO-style timestamps sort chronologically as plain strings.
    image_history.sort(key=lambda item: item[1])
    manage_history_size()
| |
|
def get_mqtt_status():
    """Return the current MQTT status HTML for the UI banner (poll target)."""
    return mqtt_status
| |
|
def upload_image(filepath):
    """Handle a manual upload: save, run detection, and refresh UI state.

    Mirrors handle_message() but stamps the image with the current local time
    instead of a timestamp from the MQTT payload.

    Returns:
        (annotated image path, detected object count, gallery item list)
    """
    image = Image.open(filepath)
    if image.mode == "RGBA":
        # JPEG cannot store an alpha channel.
        image = image.convert("RGB")
    image.save(IMAGE_PATH)
    localtime = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Timestamp becomes the history filename; spaces/colons are not filesystem-safe.
    image_history_path = os.path.join(IMAGE_HISTORY_DIR, f"{localtime.replace(' ', '_').replace(':', '-')}.jpg")
    image.save(image_history_path)
    prediction = predict_image_json(image, current_task, current_prompt)
    annotated_image_path = annotate_image(image, prediction, current_task)
    predicted_objects = predicted_objects_num(prediction, current_task)
    latest_image_info.update({
        "path": annotated_image_path,
        "date": localtime,
        "objnum": predicted_objects
    })
    image_history.append((image_history_path, localtime))
    manage_history_size()
    return annotated_image_path, predicted_objects, display_image_history()
| |
|
def manage_history_size():
    """Evict the oldest history entries (and their files) down to MAX_HISTORY_SIZE.

    Fix: trims until the history is within the limit instead of removing a
    fixed 2 entries per call, which left the history permanently over the
    limit after a bulk history_image_load() of a large directory.
    """
    global image_history
    while len(image_history) > MAX_HISTORY_SIZE:
        oldest_path = image_history.pop(0)[0]
        os.remove(oldest_path)
| |
|
def commit_prompt(prompt):
    """Commit button handler: re-run detection on the latest image with *prompt*.

    An empty prompt falls back to "all".
    Returns (annotated image path, detected object count) for the UI.
    """
    global current_prompt
    print(f"Updating prompt to: {prompt}")
    if prompt == "":
        prompt = "all"
    current_prompt = prompt
    # Re-analyze the most recently received/selected raw image.
    image = Image.open(IMAGE_PATH)
    prediction = predict_image_json(image, current_task, current_prompt)
    annotated_image_path = annotate_image(image, prediction, current_task)
    predicted_objects = predicted_objects_num(prediction, current_task)
    latest_image_info["path"] = annotated_image_path
    latest_image_info["objnum"] = predicted_objects
    return annotated_image_path, predicted_objects
| |
|
def update_task(task, prompt):
    """Dropdown handler: switch the active detection task.

    Also updates the module-level current_prompt ("" for plain <OD>, the
    textbox value otherwise) and toggles the prompt textbox's visibility.

    Fix: `current_prompt` was assigned without being declared global, so the
    assignments created a throwaway local and the module-level prompt was
    never actually updated.
    """
    global current_task, current_prompt
    task = task_name[task]
    current_task = task
    if task == "<OD>":
        current_prompt = ""
    else:
        current_prompt = prompt
    print(f"Updating task to: {task}, prompt to: {current_prompt}")
    # The prompt box is only meaningful for prompt-driven tasks.
    return gr.update(visible=task != "<OD>")
| |
|
# ---------------------------------------------------------------------------
# Gradio UI layout, event wiring, and app startup (runs at import time).
# ---------------------------------------------------------------------------
with gr.Blocks(css="footer {visibility: hidden}") as iface:
    gr.Markdown("## MS Computer Vision Demo")
    # Connection banner; refreshed once on page load (see iface.load below).
    mqtt_status_output = gr.HTML(value=mqtt_status)

    with gr.Accordion("MQTT Settings", open=False):
        with gr.Row():
            broker_input = gr.Textbox(label="MQTT Broker", value=MQTT_CONFIG["broker"])
            port_input = gr.Textbox(label="MQTT Port", value=str(MQTT_CONFIG["port"]))
            topic_input = gr.Textbox(label="MQTT Topic", value=MQTT_CONFIG["topic"])
        with gr.Row():
            username_input = gr.Textbox(label="MQTT Username", value=MQTT_CONFIG["username"])
            password_input = gr.Textbox(label="MQTT Password", type="password", value=MQTT_CONFIG["password"])
        connect_button = gr.Button("Connect")
        # NOTE(review): update_mqtt_config returns a status string, but
        # outputs=[] discards it — nothing in the UI shows the result.
        connect_button.click(
            fn=update_mqtt_config,
            inputs=[broker_input, port_input, topic_input, username_input, password_input],
            outputs=[]
        )

    with gr.Row():
        with gr.Column(scale=2):
            image_output = gr.Image(label="Latest Image")
            detected_objects_output = gr.Textbox(label="Detected Objects Count", placeholder="No objects detected", interactive=False)
            task_input = gr.Dropdown(
                label="Task",
                choices=list(task_name.keys()),
                value="detect all objects"
            )
            # Hidden for plain <OD>; update_task toggles visibility.
            prompt_input = gr.Textbox(label="Prompt(Optional)", placeholder="what is object want to detect?", visible=False)
            task_input.change(fn=update_task, inputs=[task_input, prompt_input], outputs=[prompt_input])
            commit_button = gr.Button("Commit")
            commit_button.click(fn=commit_prompt, inputs=[prompt_input], outputs=[image_output, detected_objects_output])
        with gr.Column(scale=1):
            history_output = gr.Gallery(label="History Image", columns=3)
            upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
            upload_button.upload(fn=upload_image, inputs=upload_button, outputs=[image_output, detected_objects_output, history_output])

    # Thin wrappers used as Gradio event targets.
    def refresh_interface():
        return display_image()

    def refresh_history():
        return display_image_history()

    # Whenever the gallery contents change, refresh the latest-image panel.
    history_output.change(fn=refresh_interface, outputs=[image_output, detected_objects_output])

    # Load existing history from disk, then poll it into the gallery twice
    # a second so MQTT-received images appear without a manual refresh.
    history_image_load()
    iface.load(fn=refresh_history, inputs=[], outputs=history_output, every=0.5)

    # Connect to the broker immediately and show the status on page load.
    auto_connect()
    iface.load(fn=get_mqtt_status, inputs=[], outputs=mqtt_status_output)
    history_output.select(fn=show_prediction_on_history, outputs=[image_output, detected_objects_output])

# share=True exposes a public gradio.live tunnel to this app.
iface.launch(share=True)
| |
|