Spaces:
Sleeping
Sleeping
altawil
commited on
Update model_definition.py
Browse files- model_definition.py +213 -126
model_definition.py
CHANGED
|
@@ -1187,152 +1187,239 @@ class InterfuserModel(nn.Module):
|
|
| 1187 |
# ============================================================================
|
| 1188 |
# دوال مساعدة لتحميل النموذج
|
| 1189 |
# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1190 |
|
| 1191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1192 |
"""
|
| 1193 |
-
|
| 1194 |
-
|
| 1195 |
-
|
| 1196 |
-
|
| 1197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1198 |
|
| 1199 |
-
|
| 1200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1201 |
"""
|
| 1202 |
try:
|
| 1203 |
-
|
| 1204 |
-
|
| 1205 |
-
|
| 1206 |
|
| 1207 |
-
#
|
| 1208 |
-
|
| 1209 |
-
|
| 1210 |
-
|
| 1211 |
-
|
| 1212 |
-
|
| 1213 |
-
|
| 1214 |
-
logging.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية")
|
| 1215 |
-
else:
|
| 1216 |
-
logging.info("لم يتم تحديد مسار الأوزان، سيتم استخدام أوزان عشوائية")
|
| 1217 |
|
| 1218 |
-
# وضع النموذج في وضع التقييم
|
| 1219 |
model.eval()
|
|
|
|
| 1220 |
|
| 1221 |
return model
|
| 1222 |
|
| 1223 |
except Exception as e:
|
| 1224 |
-
logging.error(f"
|
| 1225 |
raise
|
| 1226 |
|
| 1227 |
|
| 1228 |
-
|
| 1229 |
-
|
| 1230 |
-
|
|
|
|
| 1231 |
|
| 1232 |
-
|
| 1233 |
-
|
| 1234 |
-
|
| 1235 |
|
| 1236 |
-
|
| 1237 |
-
|
| 1238 |
-
|
| 1239 |
-
|
| 1240 |
-
|
| 1241 |
-
|
| 1242 |
-
|
| 1243 |
-
"enc_depth": 6,
|
| 1244 |
-
"dec_depth": 6,
|
| 1245 |
-
"rgb_backbone_name": 'r50',
|
| 1246 |
-
"lidar_backbone_name": 'r18',
|
| 1247 |
-
"waypoints_pred_head": 'gru',
|
| 1248 |
-
"use_different_backbone": True,
|
| 1249 |
-
"with_lidar": False,
|
| 1250 |
-
"with_right_left_sensors": False,
|
| 1251 |
-
"with_center_sensor": False,
|
| 1252 |
|
| 1253 |
-
|
| 1254 |
-
|
| 1255 |
-
|
| 1256 |
-
|
| 1257 |
-
|
| 1258 |
-
|
| 1259 |
-
|
| 1260 |
-
|
| 1261 |
-
|
| 1262 |
-
|
| 1263 |
-
|
| 1264 |
-
|
| 1265 |
-
|
| 1266 |
-
|
| 1267 |
-
|
| 1268 |
-
|
| 1269 |
-
|
| 1270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1271 |
|
| 1272 |
-
|
| 1273 |
-
|
|
|
|
| 1274 |
|
| 1275 |
-
|
| 1276 |
-
|
| 1277 |
-
|
| 1278 |
-
|
| 1279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1280 |
|
| 1281 |
-
|
| 1282 |
-
|
| 1283 |
-
|
| 1284 |
-
|
| 1285 |
-
|
| 1286 |
-
|
| 1287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1288 |
|
| 1289 |
-
|
| 1290 |
-
|
| 1291 |
-
|
| 1292 |
-
|
| 1293 |
-
|
| 1294 |
-
|
| 1295 |
-
|
| 1296 |
-
|
| 1297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1298 |
|
| 1299 |
-
|
| 1300 |
-
|
| 1301 |
-
|
| 1302 |
-
def get_training_config():
|
| 1303 |
-
|
| 1304 |
-
|
| 1305 |
-
|
| 1306 |
-
|
| 1307 |
-
|
| 1308 |
-
|
| 1309 |
-
|
| 1310 |
-
|
| 1311 |
-
|
| 1312 |
-
|
| 1313 |
-
|
| 1314 |
-
|
| 1315 |
-
|
| 1316 |
-
|
| 1317 |
-
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
| 1325 |
-
|
| 1326 |
-
|
| 1327 |
-
|
| 1328 |
-
|
| 1329 |
-
|
| 1330 |
-
|
| 1331 |
-
|
| 1332 |
-
|
| 1333 |
-
|
| 1334 |
-
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
|
| 1338 |
-
|
|
|
|
| 1187 |
# ============================================================================
|
| 1188 |
# دوال مساعدة لتحميل النموذج
|
| 1189 |
# ============================================================================
|
| 1190 |
+
# ==============================================================================
|
| 1191 |
+
# ملف: config_and_loader.py
|
| 1192 |
+
# هذا هو المصدر الوحيد للحقيقة لجميع الإعدادات وعملية تحميل النموذج.
|
| 1193 |
+
# ==============================================================================
|
| 1194 |
|
| 1195 |
+
import logging
|
| 1196 |
+
import torch
|
| 1197 |
+
from interfuser_model import Interfuser # تأكد من أن اسم ملف النموذج صحيح
|
| 1198 |
+
|
| 1199 |
+
def get_master_config(model_path="model/best_model.pth"):
|
| 1200 |
"""
|
| 1201 |
+
[النسخة الكاملة والنهائية]
|
| 1202 |
+
ينشئ ويدمج كل الإعدادات المطلوبة للتطبيق (النموذج، المتتبع، المتحكم).
|
| 1203 |
+
"""
|
| 1204 |
+
model_params = {
|
| 1205 |
+
"img_size": 224, "embed_dim": 256, "enc_depth": 6, "dec_depth": 6,
|
| 1206 |
+
"rgb_backbone_name": 'r50', "lidar_backbone_name": 'r18',
|
| 1207 |
+
"waypoints_pred_head": 'gru', "use_different_backbone": True,
|
| 1208 |
+
"with_lidar": False, "with_right_left_sensors": False,
|
| 1209 |
+
"with_center_sensor": False, "multi_view_img_size": 112,
|
| 1210 |
+
"patch_size": 8, "in_chans": 3, "dim_feedforward": 2048,
|
| 1211 |
+
"normalize_before": False, "num_heads": 8, "dropout": 0.1,
|
| 1212 |
+
"end2end": False, "direct_concat": False, "separate_view_attention": False,
|
| 1213 |
+
"separate_all_attention": False, "freeze_num": -1,
|
| 1214 |
+
"traffic_pred_head_type": "det", "reverse_pos": True,
|
| 1215 |
+
"use_view_embed": False, "use_mmad_pretrain": None,
|
| 1216 |
+
}
|
| 1217 |
+
|
| 1218 |
+
grid_conf = {
|
| 1219 |
+
'h': 20, 'w': 20, 'x_res': 1.0, 'y_res': 1.0,
|
| 1220 |
+
'y_min': 0.0, 'y_max': 20.0, 'x_min': -10.0, 'x_max': 10.0,
|
| 1221 |
+
}
|
| 1222 |
+
|
| 1223 |
+
controller_params = {
|
| 1224 |
+
'turn_KP': 0.75, 'turn_KI': 0.05, 'turn_KD': 0.25, 'turn_n': 20,
|
| 1225 |
+
'speed_KP': 0.55, 'speed_KI': 0.05, 'speed_KD': 0.15, 'speed_n': 20,
|
| 1226 |
+
'max_speed': 8.0, 'max_throttle': 0.75, 'min_speed': 0.1,
|
| 1227 |
+
'brake_sensitivity': 0.3, 'light_threshold': 0.5, 'stop_threshold': 0.6,
|
| 1228 |
+
'stop_sign_duration': 20, 'max_stop_time': 250,
|
| 1229 |
+
'forced_move_duration': 20, 'forced_throttle': 0.5,
|
| 1230 |
+
'max_red_light_time': 150, 'red_light_block_duration': 80,
|
| 1231 |
+
'accel_rate': 0.1, 'decel_rate': 0.2, 'critical_distance': 4.0,
|
| 1232 |
+
'follow_distance': 10.0, 'speed_match_factor': 0.9,
|
| 1233 |
+
'tracker_match_thresh': 2.5, 'tracker_prune_age': 5,
|
| 1234 |
+
'follow_grace_period': 20
|
| 1235 |
+
}
|
| 1236 |
+
|
| 1237 |
+
master_config = {
|
| 1238 |
+
'model_params': model_params,
|
| 1239 |
+
'grid_conf': grid_conf,
|
| 1240 |
+
'controller_params': controller_params,
|
| 1241 |
+
'paths': {'pretrained_weights': model_path},
|
| 1242 |
+
'simulation': {'frequency': 10.0}
|
| 1243 |
+
}
|
| 1244 |
|
| 1245 |
+
return master_config
|
| 1246 |
+
|
| 1247 |
+
|
| 1248 |
+
def load_and_prepare_model(device: torch.device) -> Interfuser:
|
| 1249 |
+
"""
|
| 1250 |
+
[النسخة النهائية الصحيحة - تستقبل مدخلاً واحدًا فقط]
|
| 1251 |
+
تستخدم دالة الإعدادات الرئيسية لإنشاء وتحميل النموذج.
|
| 1252 |
"""
|
| 1253 |
try:
|
| 1254 |
+
logging.info("Attempting to load model using master config...")
|
| 1255 |
+
# 1. الحصول على كل الإعدادات من المصدر الوحيد للحقيقة
|
| 1256 |
+
config = get_master_config()
|
| 1257 |
|
| 1258 |
+
# 2. إنشاء النموذج باستخدام إعدادات النموذج فقط
|
| 1259 |
+
model = Interfuser(**config['model_params']).to(device)
|
| 1260 |
+
logging.info(f"Model instantiated on device: {device}")
|
| 1261 |
+
|
| 1262 |
+
# 3. تحميل الأوزان باستخدام الدالة الداخلية للنموذج
|
| 1263 |
+
checkpoint_path = config['paths']['pretrained_weights']
|
| 1264 |
+
model.load_pretrained(checkpoint_path, strict=False)
|
|
|
|
|
|
|
|
|
|
| 1265 |
|
| 1266 |
+
# 4. وضع النموذج في وضع التقييم
|
| 1267 |
model.eval()
|
| 1268 |
+
logging.info("✅ Model prepared and set to evaluation mode.")
|
| 1269 |
|
| 1270 |
return model
|
| 1271 |
|
| 1272 |
except Exception as e:
|
| 1273 |
+
logging.error(f"❌ CRITICAL ERROR in load_and_prepare_model: {e}", exc_info=True)
|
| 1274 |
raise
|
| 1275 |
|
| 1276 |
|
| 1277 |
+
|
| 1278 |
+
# def load_and_prepare_model(config, device):
|
| 1279 |
+
# """
|
| 1280 |
+
# يقوم بإنشاء النموذج وتحميل الأوزان المدربة مسبقًا.
|
| 1281 |
|
| 1282 |
+
# Args:
|
| 1283 |
+
# config (dict): إعدادات النموذج والمسارات
|
| 1284 |
+
# device (torch.device): الجهاز المستهدف (CPU/GPU)
|
| 1285 |
|
| 1286 |
+
# Returns:
|
| 1287 |
+
# InterfuserModel: النموذج المحمل
|
| 1288 |
+
# """
|
| 1289 |
+
# try:
|
| 1290 |
+
# # إنشاء النموذج
|
| 1291 |
+
# model = InterfuserModel(**config.get('model_params', {})).to(device)
|
| 1292 |
+
# logging.info(f"تم إنشاء النموذج على الجهاز: {device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1293 |
|
| 1294 |
+
# # تحميل الأوزان إذا كان المسار محدد
|
| 1295 |
+
# checkpoint_path = config.get('paths', {}).get('pretrained_weights')
|
| 1296 |
+
# if checkpoint_path:
|
| 1297 |
+
# success = model.load_pretrained(checkpoint_path, strict=False)
|
| 1298 |
+
# if success:
|
| 1299 |
+
# logging.info("✅ تم تحميل النموذج والأوزان بنجاح")
|
| 1300 |
+
# else:
|
| 1301 |
+
# logging.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية")
|
| 1302 |
+
# else:
|
| 1303 |
+
# logging.info("لم يتم تحديد مسار الأوزان، سيتم استخدام أوزان عشوائية")
|
| 1304 |
+
|
| 1305 |
+
# # وضع النموذج في وضع التقييم
|
| 1306 |
+
# model.eval()
|
| 1307 |
+
|
| 1308 |
+
# return model
|
| 1309 |
+
|
| 1310 |
+
# except Exception as e:
|
| 1311 |
+
# logging.error(f"خطأ في إنشاء النموذج: {str(e)}")
|
| 1312 |
+
# raise
|
| 1313 |
+
|
| 1314 |
+
|
| 1315 |
+
# def create_model_config(model_path="model/best_model.pth", **model_params):
|
| 1316 |
+
# """
|
| 1317 |
+
# إنشاء إعدادات النموذج باستخدام الإعدادات الصحيحة من التدريب
|
| 1318 |
|
| 1319 |
+
# Args:
|
| 1320 |
+
# model_path (str): مسار ملف الأوزان
|
| 1321 |
+
# **model_params: معاملات النموذج الإضافية
|
| 1322 |
|
| 1323 |
+
# Returns:
|
| 1324 |
+
# dict: إعدادات النموذج
|
| 1325 |
+
# """
|
| 1326 |
+
# # الإعدادات الصحيحة من كونفيج التدريب الأصلي
|
| 1327 |
+
# training_config_params = {
|
| 1328 |
+
# "img_size": 224,
|
| 1329 |
+
# "embed_dim": 256, # مهم: هذه القيمة من التدريب الأصلي
|
| 1330 |
+
# "enc_depth": 6,
|
| 1331 |
+
# "dec_depth": 6,
|
| 1332 |
+
# "rgb_backbone_name": 'r50',
|
| 1333 |
+
# "lidar_backbone_name": 'r18',
|
| 1334 |
+
# "waypoints_pred_head": 'gru',
|
| 1335 |
+
# "use_different_backbone": True,
|
| 1336 |
+
# "with_lidar": False,
|
| 1337 |
+
# "with_right_left_sensors": False,
|
| 1338 |
+
# "with_center_sensor": False,
|
| 1339 |
|
| 1340 |
+
# # إعدادات إضافية من الكونفيج الأصلي
|
| 1341 |
+
# "multi_view_img_size": 112,
|
| 1342 |
+
# "patch_size": 8,
|
| 1343 |
+
# "in_chans": 3,
|
| 1344 |
+
# "dim_feedforward": 2048,
|
| 1345 |
+
# "normalize_before": False,
|
| 1346 |
+
# "num_heads": 8,
|
| 1347 |
+
# "dropout": 0.1,
|
| 1348 |
+
# "end2end": False,
|
| 1349 |
+
# "direct_concat": False,
|
| 1350 |
+
# "separate_view_attention": False,
|
| 1351 |
+
# "separate_all_attention": False,
|
| 1352 |
+
# "freeze_num": -1,
|
| 1353 |
+
# "traffic_pred_head_type": "det",
|
| 1354 |
+
# "reverse_pos": True,
|
| 1355 |
+
# "use_view_embed": False,
|
| 1356 |
+
# "use_mmad_pretrain": None,
|
| 1357 |
+
# }
|
| 1358 |
+
|
| 1359 |
+
# # دمج المعاملات المخصصة مع الإعدادات من التدريب
|
| 1360 |
+
# training_config_params.update(model_params)
|
| 1361 |
+
|
| 1362 |
+
# config = {
|
| 1363 |
+
# 'model_params': training_config_params,
|
| 1364 |
+
# 'paths': {
|
| 1365 |
+
# 'pretrained_weights': model_path
|
| 1366 |
+
# },
|
| 1367 |
|
| 1368 |
+
# # إضافة إعدادات الشبكة من التدريب
|
| 1369 |
+
# 'grid_conf': {
|
| 1370 |
+
# 'h': 20, 'w': 20,
|
| 1371 |
+
# 'x_res': 1.0, 'y_res': 1.0,
|
| 1372 |
+
# 'y_min': 0.0, 'y_max': 20.0,
|
| 1373 |
+
# 'x_min': -10.0, 'x_max': 10.0,
|
| 1374 |
+
# },
|
| 1375 |
+
|
| 1376 |
+
# # معلومات إضافية عن التدريب
|
| 1377 |
+
# 'training_info': {
|
| 1378 |
+
# 'original_project': 'Interfuser_Finetuning',
|
| 1379 |
+
# 'run_name': 'Finetune_Focus_on_Detection_v5',
|
| 1380 |
+
# 'focus': 'traffic_detection_and_iou',
|
| 1381 |
+
# 'backbone': 'ResNet50 + ResNet18',
|
| 1382 |
+
# 'trained_on': 'PDM_Lite_Carla'
|
| 1383 |
+
# }
|
| 1384 |
+
# }
|
| 1385 |
|
| 1386 |
+
# return config
|
| 1387 |
+
|
| 1388 |
+
|
| 1389 |
+
# def get_training_config():
|
| 1390 |
+
# """
|
| 1391 |
+
# إرجاع إعدادات التدريب الأصلية للمرجع
|
| 1392 |
+
# هذه الإعدادات توضح كيف تم تدريب النموذج
|
| 1393 |
+
# """
|
| 1394 |
+
# return {
|
| 1395 |
+
# 'project_info': {
|
| 1396 |
+
# 'project': 'Interfuser_Finetuning',
|
| 1397 |
+
# 'entity': None,
|
| 1398 |
+
# 'run_name': 'Finetune_Focus_on_Detection_v5'
|
| 1399 |
+
# },
|
| 1400 |
+
# 'training': {
|
| 1401 |
+
# 'epochs': 50,
|
| 1402 |
+
# 'batch_size': 8,
|
| 1403 |
+
# 'num_workers': 2,
|
| 1404 |
+
# 'learning_rate': 1e-4, # معدل تعلم منخفض للـ Fine-tuning
|
| 1405 |
+
# 'weight_decay': 1e-2,
|
| 1406 |
+
# 'patience': 15,
|
| 1407 |
+
# 'clip_grad_norm': 1.0,
|
| 1408 |
+
# },
|
| 1409 |
+
# 'loss_weights': {
|
| 1410 |
+
# 'iou': 2.0, # أولوية قصوى لدقة الصناديق
|
| 1411 |
+
# 'traffic_map': 25.0, # تركيز عالي على اكتشاف الكائنات
|
| 1412 |
+
# 'waypoints': 1.0, # مرجع أساسي
|
| 1413 |
+
# 'junction': 0.25, # مهام متقنة بالفعل
|
| 1414 |
+
# 'traffic_light': 0.5,
|
| 1415 |
+
# 'stop_sign': 0.25,
|
| 1416 |
+
# },
|
| 1417 |
+
# 'data_split': {
|
| 1418 |
+
# 'strategy': 'interleaved',
|
| 1419 |
+
# 'segment_length': 100,
|
| 1420 |
+
# 'validation_frequency': 10,
|
| 1421 |
+
# },
|
| 1422 |
+
# 'transforms': {
|
| 1423 |
+
# 'use_data_augmentation': False, # معطل للتركيز على البيانات الأصلية
|
| 1424 |
+
# }
|
| 1425 |
+
# }
|