car_env / quantum_training.py
Hajorda's picture
Upload folder using huggingface_hub
1a0d68d verified
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy
from geometry_msgs.msg import Twist, Pose
from sensor_msgs.msg import LaserScan
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import time
import os
import math
# --- KUANTUM VE PYTORCH KÜTÜPHANELERİ ---
import torch
import torch.nn as nn
import pennylane as qml
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
# ====================================================
# 1. KUANTUM DEVRESİ VE KATMANI (QNN)
# ====================================================
n_qubits = 4 # 4 Qubit (Simülasyon hızı için ideal)
n_layers = 2 # Kuantum derinliği
dev = qml.device("default.qubit", wires=n_qubits) # Simülatör Cihazı
@qml.qnode(dev, interface="torch")
def quantum_circuit(inputs, weights):
# Veriyi Qubitlere gömüyoruz (Angle Embedding)
qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
# Kuantum Ağırlıkları ve Dolanıklık (Entanglement)
qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
# Ölçüm (Z ekseninde beklenen değer)
return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]
class HybridQuantumLayer(nn.Module):
def __init__(self):
super().__init__()
weight_shapes = {"weights": (n_layers, n_qubits, 3)}
# PennyLane devresini PyTorch katmanına dönüştür
self.q_layer = qml.qnn.TorchLayer(quantum_circuit, weight_shapes)
def forward(self, x):
return self.q_layer(x)
# ====================================================
# 2. ÖZEL ÖZELLİK ÇIKARICI (SB3 ENTEGRASYONU)
# ====================================================
class QuantumFeatureExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space: spaces.Box, features_dim: int = 4):
super().__init__(observation_space, features_dim)
# A) Klasik Ön İşleme (10 Lidar verisini 4 Qubit girişine indir)
self.classical_pre_net = nn.Sequential(
nn.Linear(10, 32),
nn.ReLU(),
nn.Linear(32, n_qubits),
nn.Tanh() # Girişleri -PI ile +PI arasına sıkıştır (Açı için şart)
)
# B) Kuantum İşleme
self.quantum_net = HybridQuantumLayer()
def forward(self, observations: torch.Tensor) -> torch.Tensor:
# Klasik ağdan geçir
x = self.classical_pre_net(observations)
# Kuantum devresinden geçir
x = self.quantum_net(x)
return x
# ====================================================
# 3. ROBOT ORTAMI (RC CAR ENV)
# ====================================================
class RcCarEnv(gym.Env):
def __init__(self):
super(RcCarEnv, self).__init__()
if not rclpy.ok():
rclpy.init()
self.node = rclpy.create_node('quantum_ackermann_agent')
# QoS Ayarları
qos_reliable = QoSProfile(reliability=ReliabilityPolicy.RELIABLE, depth=10)
qos_best_effort = QoSProfile(reliability=ReliabilityPolicy.BEST_EFFORT, depth=10)
self.pub_cmd = self.node.create_publisher(Twist, '/cmd_vel', 10)
self.sub_scan = self.node.create_subscription(LaserScan, '/scan', self.scan_cb, qos_best_effort)
# Parametreler
self.max_speed = 0.5
self.max_steering = 0.45
self.steering_smooth_factor = 0.2
self.current_steering = 0.0
self.scan_data = None
self.scan_received = False
# Action Space: [Gaz, Direksiyon]
self.action_space = spaces.Box(low=np.array([0.0, -1.0]), high=np.array([1.0, 1.0]), dtype=np.float32)
# Observation Space: 10 Lidar noktası
self.n_obs = 10
self.max_range = 10.0
self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(self.n_obs,), dtype=np.float32)
self.step_count = 0
self.max_steps = 2000
def scan_cb(self, msg):
raw = np.array(msg.ranges)
raw = np.where(np.isinf(raw), self.max_range, raw)
raw = np.where(np.isnan(raw), self.max_range, raw)
raw = np.clip(raw, 0.0, self.max_range)
chunk = len(raw) // self.n_obs
if chunk > 0:
self.scan_data = np.array([np.min(raw[i*chunk:(i+1)*chunk]) for i in range(self.n_obs)], dtype=np.float32)
else:
self.scan_data = np.full(self.n_obs, self.max_range, dtype=np.float32)
self.scan_received = True
def step(self, action):
twist = Twist()
# 1. Action'ı Uygula
# Hız (Negatif hıza izin vermiyoruz, sadece ileri)
# Action[0] normalde -1 ile 1 arasındadır. (action[0] + 1) / 2 yaparak 0-1 arasına çekiyoruz.
# Böylece robot geri gidip kaçmayı öğrenmez, hep ileri bakar.
forward_action = (float(action[0]) + 1.0) / 2.0
target_speed = forward_action * self.max_speed
twist.linear.x = target_speed
# Direksiyon
target_steering = float(action[1]) * self.max_steering
self.current_steering = (1.0 - self.steering_smooth_factor) * self.current_steering + \
(self.steering_smooth_factor * target_steering)
twist.angular.z = self.current_steering
self.pub_cmd.publish(twist)
# 2. Veri Bekleme
self.scan_received = False
start_wait = time.time()
while not self.scan_received:
rclpy.spin_once(self.node, timeout_sec=0.01)
if time.time() - start_wait > 0.2: break
if self.scan_data is not None:
observation = self.scan_data / self.max_range
else:
observation = np.ones(self.n_obs, dtype=np.float32)
self.step_count += 1
# --- 3. ÖDÜL FONKSİYONU (ANTİ-TEMBELLİK GÜNCELLEMESİ) ---
real_min_dist = np.min(observation) * self.max_range
reward = 0.0
# KURAL 1: Sadece Hızlı Gidiyorsa Hayatta Kalma Puanı Ver
# Eskiden: Sadece nefes aldığı için puan alıyordu.
# Şimdi: Eğer hızı 0.1 m/s üzerindeyse puan alır.
if twist.linear.x > 0.1:
reward += 0.1 # Hızlanmaya teşvik
reward += twist.linear.x * 0.5 # Hızlandıkça artan ödül
else:
reward -= 0.05 # Durduğu yerde durursa ceza (Hadi kımılda!)
# KURAL 2: Direksiyon Cezasını Artır (Düz gitmeyi sevdir)
# Sürekli dönmek artık pahalı. Düz gitmek bedava.
# Direksiyonu ne kadar kırarsa o kadar ceza.
reward -= abs(self.current_steering) * 0.6 # Ceza katsayısını 0.1'den 0.6'ya çıkardık!
# KURAL 3: Duvara Yaklaşma Cezası (Güvenlik)
if real_min_dist < 0.6:
reward -= (0.6 - real_min_dist) * 2.0
terminated = False
truncated = False
# Çarpışma Kontrolü
if real_min_dist < 0.25:
reward = -20.0 # Cezayı biraz artırdık
terminated = True
print(f"💥 Bam! ({self.step_count} adım)")
self.pub_cmd.publish(Twist()) # Dur
if self.step_count >= self.max_steps:
truncated = True
return observation.astype(np.float32), reward, terminated, truncated, {}
def reset(self, seed=None, options=None):
super().reset(seed=seed)
print("\n⚛️ Kuantum Reset (Direct Gazebo Call)...")
self.current_steering = 0.0
self.pub_cmd.publish(Twist())
# ZORLA RESET (Direct Service Call)
cmd = """gz service -s /world/arena/set_pose \
--reqtype gz.msgs.Pose \
--reptype gz.msgs.Boolean \
--timeout 300 \
--req 'name: "my_rc_car", position: {x: 0, y: 0, z: 0.2}, orientation: {w: 1}'"""
os.system(cmd)
self.scan_data = None
self.scan_received = False
start_wait = time.time()
while time.time() - start_wait < 1.0:
rclpy.spin_once(self.node, timeout_sec=0.1)
if self.scan_data is not None:
obs = self.scan_data / self.max_range
else:
obs = np.ones(self.n_obs, dtype=np.float32)
self.step_count = 0
return obs.astype(np.float32), {}
def close(self):
self.pub_cmd.publish(Twist())
self.node.destroy_node()
# ====================================================
# 4. MAIN - EĞİTİMİ BAŞLAT
# ====================================================
if __name__ == '__main__':
from stable_baselines3 import PPO
log_dir = "./logs_quantum/"
os.makedirs(log_dir, exist_ok=True)
env = RcCarEnv()
print("---------------------------------------")
print("⚛️ HİBRİT KUANTUM RL EĞİTİMİ BAŞLIYOR")
print("🧠 4 Qubit Devrede | PennyLane & PyTorch")
print("❤️ Engin'i Seviyorum! Geleceği Kodluyoruz!")
print("---------------------------------------")
# Kuantum Feature Extractor'ı Policy'ye enjekte ediyoruz
policy_kwargs = dict(
features_extractor_class=QuantumFeatureExtractor,
features_extractor_kwargs=dict(features_dim=n_qubits),
)
model = PPO(
"MlpPolicy",
env,
verbose=1,
tensorboard_log=log_dir,
learning_rate=3e-4, # Kuantum ağları yavaş öğrenir, düşük LR iyidir
policy_kwargs=policy_kwargs, # <--- KUANTUM KATMANI BURADA
batch_size=32, # Kuantum simülasyonu yavaş olduğu için batch'i küçülttüm
n_steps=512, # Update sıklığını artırdım
gamma=0.99
)
try:
# Kuantum simülasyonu yavaş olduğu için total_timesteps'i başta düşük tuttum
# Memnun kalırsan artırabilirsin.
model.learn(total_timesteps=50000)
model.save("rc_car_quantum_master")
print("✅ Kuantum Eğitimi Başarıyla Tamamlandı!")
except KeyboardInterrupt:
model.save("rc_car_quantum_interrupted")
print("\n⚠️ Eğitim durduruldu ve kaydedildi.")
finally:
env.close()