# Basee_model / app.py
# Uploaded by mohammed-aljafry via huggingface_hub (commit 2ef5d40, verified; 11.9 kB)
# (Hugging Face file-viewer residue — "raw / history / blame" — kept as a comment
# so the module remains valid Python.)
# app.py
import os
import json
import traceback
import torch
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import math
import logging
# ==============================================================================
# 1. إعداد الاستيرادات والإعدادات الأساسية
# ==============================================================================
# إعداد بسيط لعرض الرسائل الإعلامية والأخطاء
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- استيراد من وحدات المشروع المنظمة ---
try:
from model_definition import create_model_config, load_and_prepare_model
except ImportError:
raise ImportError("فشل استيراد من 'model.architecture'. تأكد من وجود الملف وأن مسار بايثون صحيح.")
try:
from simulation_modules import (
InterfuserController, ControllerConfig, DisplayInterface,
render, render_waypoints, render_self_car, ensure_rgb,
WAYPOINT_SCALE_FACTOR, T1_FUTURE_TIME, T2_FUTURE_TIME,
transform, lidar_transform, Tracker # تأكد من وجود transform هنا
)
except ImportError:
raise ImportError("فشل استيراد من 'simulation_modules'. تأكد من وجود الملف وأن مسار بايثون صحيح.")
# --- إعدادات ومسارات النماذج ---
WEIGHTS_DIR = "model"
EXAMPLES_DIR = "examples"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def find_available_models():
if not os.path.isdir(WEIGHTS_DIR): return []
return [f.replace(".pth", "") for f in os.listdir(WEIGHTS_DIR) if f.endswith(".pth")]
# ==============================================================================
# 2. الدوال الأساسية (load_model, run_single_frame)
# ==============================================================================
def load_model(model_name: str):
"""
(نسخة مبسطة)
تستخدم الآن الدوال المساعدة الجديدة لإنشاء وتحميل النموذج.
"""
if not model_name or "لم يتم" in model_name:
return None, "الرجاء اختيار نموذج صالح."
weights_path = os.path.join(WEIGHTS_DIR, f"{model_name}.pth")
# 1. إنشاء إعدادات النموذج المتوافقة مع الأوزان
config = create_model_config(model_path=weights_path)
# 2. إنشاء وتحميل النموذج بخطوة واحدة
try:
model = load_and_prepare_model(config, device)
status_message = f"تم تحميل نموذج: {model_name}"
if model is None: raise RuntimeError("فشلت دالة load_and_prepare_model")
except Exception as e:
model = None
status_message = f"فشل تحميل النموذج: {e}"
logging.error(traceback.format_exc())
return model, status_message
def run_single_frame(
model_from_state, rgb_image_path, rgb_left_image_path, rgb_right_image_path,
rgb_center_image_path, lidar_image_path, measurements_path, target_point_list
):
"""
تعتمد الآن على الوحدات المستوردة بشكل كامل.
"""
if model_from_state is None:
print("API session detected or model not loaded. Loading default model...")
available_models = find_available_models()
if not available_models: raise gr.Error("لا توجد نماذج متاحة للتحميل.")
model_to_use, _ = load_model(available_models[0])
else:
model_to_use = model_from_state
if model_to_use is None:
raise gr.Error("فشل تحميل النموذج. تحقق من السجلات (Logs).")
try:
if not (rgb_image_path and measurements_path):
raise gr.Error("الرجاء توفير الصورة الأمامية وملف القياسات على الأقل.")
# --- 1. قراءة ومعالجة المدخلات ---
rgb_image_pil = Image.open(rgb_image_path).convert("RGB")
rgb_left_pil = Image.open(rgb_left_image_path).convert("RGB") if rgb_left_image_path else rgb_image_pil
rgb_right_pil = Image.open(rgb_right_image_path).convert("RGB") if rgb_right_image_path else rgb_image_pil
rgb_center_pil = Image.open(rgb_center_image_path).convert("RGB") if rgb_center_image_path else rgb_image_pil
front_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
if lidar_image_path:
lidar_array = np.load(lidar_image_path)
if lidar_array.max() > 0: lidar_array = (lidar_array / lidar_array.max()) * 255.0
lidar_pil = Image.fromarray(lidar_array.astype(np.uint8)).convert('RGB')
else:
lidar_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
lidar_tensor = lidar_transform(lidar_pil).unsqueeze(0).to(device)
with open(measurements_path, 'r') as f: m_dict = json.load(f)
measurements_tensor = torch.tensor([[
m_dict.get('x',0.0), m_dict.get('y',0.0), m_dict.get('theta',0.0), m_dict.get('speed',5.0),
m_dict.get('steer',0.0), m_dict.get('throttle',0.0), float(m_dict.get('brake',0.0)),
m_dict.get('command',2.0), float(m_dict.get('is_junction',0.0)), float(m_dict.get('should_brake',0.0))
]], dtype=torch.float32).to(device)
target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
inputs = {'rgb': front_tensor, 'rgb_left': left_tensor, 'rgb_right': right_tensor, 'rgb_center': center_tensor, 'lidar': lidar_tensor, 'measurements': measurements_tensor, 'target_point': target_point_tensor}
# --- 2. تشغيل النموذج ---
with torch.no_grad():
outputs = model_to_use(inputs)
traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
# --- 3. المعالجة اللاحقة والتصوّر ---
speed = m_dict.get('speed', 5.0)
controller = InterfuserController(ControllerConfig())
steer, throttle, brake, metadata_str = controller.run_step(speed, waypoints, is_junction.sigmoid()[0,1].item(), traffic_light.sigmoid()[0,0].item(), stop_sign.sigmoid()[0,1].item(), {})
map_t0, _ = render(traffic[0])
map_t1, _ = render(traffic[0], t=T1_FUTURE_TIME)
map_t2, _ = render(traffic[0], t=T2_FUTURE_TIME)
wp_map = render_waypoints(waypoints[0])
map_t0 = cv2.add(map_t0, wp_map)
map_t0 = render_self_car(map_t0)
display = DisplayInterface()
interface_data = {'camera_view': np.array(rgb_image_pil),'map_t0': map_t0,'map_t1': map_t1,'map_t2': map_t2,
'text_info': {'Control': f"S:{steer:.2f} T:{throttle:.2f} B:{int(brake)}", 'Metadata': metadata_str}}
dashboard_image = display.run_interface(interface_data)
# --- 4. تجهيز المخرجات ---
control_commands_dict = {"steer": steer, "throttle": throttle, "brake": bool(brake)}
return Image.fromarray(dashboard_image), control_commands_dict
except Exception as e:
logging.error(traceback.format_exc())
raise gr.Error(f"حدث خطأ أثناء معالجة الإطار: {e}")
# ==============================================================================
# 3. تعريف واجهة Gradio (لا تغيير هنا)
# ==============================================================================
available_models = find_available_models()
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
model_state = gr.State(value=None)
gr.Markdown("# 🚗 محاكاة القيادة الذاتية باستخدام Interfuser")
gr.Markdown("مرحباً بك في واجهة اختبار نموذج Interfuser.")
with gr.Row():
with gr.Column(scale=1):
with gr.Group():
gr.Markdown("## ⚙️ الخطوة 1: اختر النموذج")
with gr.Row():
model_selector = gr.Dropdown(label="النماذج المتاحة", choices=available_models, value=available_models[0] if available_models else "لم يتم العثور على نماذج")
status_textbox = gr.Textbox(label="حالة النموذج", interactive=False)
with gr.Group():
gr.Markdown("## 🗂️ الخطوة 2: ارفع ملفات السيناريو")
with gr.Group():
gr.Markdown("**(مطلوب)**")
api_rgb_image_path = gr.File(label="صورة الكاميرا الأمامية (RGB)", type="filepath")
api_measurements_path = gr.File(label="ملف القياسات (JSON)", type="filepath")
with gr.Accordion("📷 مدخلات اختيارية", open=False):
api_rgb_left_image_path = gr.File(label="كاميرا اليسار (RGB)", type="filepath")
api_rgb_right_image_path = gr.File(label="كاميرا اليمين (RGB)", type="filepath")
api_rgb_center_image_path = gr.File(label="كاميرا الوسط (RGB)", type="filepath")
api_lidar_image_path = gr.File(label="بيانات الليدار (NPY)", type="filepath")
api_target_point_list = gr.JSON(label="📍 النقطة المستهدفة (x, y)", value=[0.0, 100.0])
api_run_button = gr.Button("🚀 شغل المحاكاة", variant="primary", scale=2)
with gr.Group():
gr.Markdown("### ✨ أمثلة جاهزة")
gr.Examples(
examples=[
[os.path.join(EXAMPLES_DIR, "sample1", "rgb.jpg"), os.path.join(EXAMPLES_DIR, "sample1", "measurements.json")],
[os.path.join(EXAMPLES_DIR, "sample2", "rgb.jpg"), os.path.join(EXAMPLES_DIR, "sample1", "measurements.json")]
],
inputs=[api_rgb_image_path, api_measurements_path], label="اختر سيناريو اختبار")
with gr.Column(scale=2):
with gr.Group():
gr.Markdown("## 📊 الخطوة 3: شاهد النتائج")
api_output_image = gr.Image(label="لوحة التحكم المرئية (Dashboard)", type="pil", interactive=False)
api_control_json = gr.JSON(label="أوامر التحكم (JSON)")
if available_models:
demo.load(fn=load_model, inputs=model_selector, outputs=[model_state, status_textbox])
model_selector.change(fn=load_model, inputs=model_selector, outputs=[model_state, status_textbox])
api_run_button.click(
fn=run_single_frame,
inputs=[model_state, api_rgb_image_path, api_rgb_left_image_path, api_rgb_right_image_path,
api_rgb_center_image_path, api_lidar_image_path, api_measurements_path, api_target_point_list],
outputs=[api_output_image, api_control_json],
api_name="run_single_frame"
)
# ==============================================================================
# 4. تشغيل التطبيق
# ==============================================================================
if __name__ == "__main__":
if not available_models:
logging.warning("لم يتم العثور على أي ملفات نماذج (.pth) في مجلد 'model/weights'.")
demo.queue().launch(debug=True, share=True)