File size: 10,363 Bytes
1a0d68d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy
from geometry_msgs.msg import Twist, Pose
from sensor_msgs.msg import LaserScan
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import time
import os
import math

# --- KUANTUM VE PYTORCH KÜTÜPHANELERİ ---
import torch
import torch.nn as nn
import pennylane as qml
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

# ====================================================
# 1. KUANTUM DEVRESİ VE KATMANI (QNN)
# ====================================================

n_qubits = 4               # 4 Qubit (Simülasyon hızı için ideal)
n_layers = 2               # Kuantum derinliği
dev = qml.device("default.qubit", wires=n_qubits) # Simülatör Cihazı

@qml.qnode(dev, interface="torch")
def quantum_circuit(inputs, weights):
    # Veriyi Qubitlere gömüyoruz (Angle Embedding)
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
    
    # Kuantum Ağırlıkları ve Dolanıklık (Entanglement)
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    
    # Ölçüm (Z ekseninde beklenen değer)
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

class HybridQuantumLayer(nn.Module):
    def __init__(self):
        super().__init__()
        weight_shapes = {"weights": (n_layers, n_qubits, 3)}
        # PennyLane devresini PyTorch katmanına dönüştür
        self.q_layer = qml.qnn.TorchLayer(quantum_circuit, weight_shapes)
        
    def forward(self, x):
        return self.q_layer(x)

# ====================================================
# 2. ÖZEL ÖZELLİK ÇIKARICI (SB3 ENTEGRASYONU)
# ====================================================

class QuantumFeatureExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Box, features_dim: int = 4):
        super().__init__(observation_space, features_dim)
        
        # A) Klasik Ön İşleme (10 Lidar verisini 4 Qubit girişine indir)
        self.classical_pre_net = nn.Sequential(
            nn.Linear(10, 32),
            nn.ReLU(),
            nn.Linear(32, n_qubits), 
            nn.Tanh() # Girişleri -PI ile +PI arasına sıkıştır (Açı için şart)
        )
        
        # B) Kuantum İşleme
        self.quantum_net = HybridQuantumLayer()
        
    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # Klasik ağdan geçir
        x = self.classical_pre_net(observations)
        # Kuantum devresinden geçir
        x = self.quantum_net(x)
        return x

# ====================================================
# 3. ROBOT ORTAMI (RC CAR ENV)
# ====================================================

class RcCarEnv(gym.Env):
    def __init__(self):
        super(RcCarEnv, self).__init__()
        
        if not rclpy.ok():
            rclpy.init()
            
        self.node = rclpy.create_node('quantum_ackermann_agent')
        
        # QoS Ayarları
        qos_reliable = QoSProfile(reliability=ReliabilityPolicy.RELIABLE, depth=10)
        qos_best_effort = QoSProfile(reliability=ReliabilityPolicy.BEST_EFFORT, depth=10)
        
        self.pub_cmd = self.node.create_publisher(Twist, '/cmd_vel', 10)
        self.sub_scan = self.node.create_subscription(LaserScan, '/scan', self.scan_cb, qos_best_effort)
        
        # Parametreler
        self.max_speed = 0.5
        self.max_steering = 0.45 
        self.steering_smooth_factor = 0.2
        
        self.current_steering = 0.0
        self.scan_data = None
        self.scan_received = False
        
        # Action Space: [Gaz, Direksiyon]
        self.action_space = spaces.Box(low=np.array([0.0, -1.0]), high=np.array([1.0, 1.0]), dtype=np.float32)
        
        # Observation Space: 10 Lidar noktası
        self.n_obs = 10
        self.max_range = 10.0
        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(self.n_obs,), dtype=np.float32)
        
        self.step_count = 0
        self.max_steps = 2000

    def scan_cb(self, msg):
        raw = np.array(msg.ranges)
        raw = np.where(np.isinf(raw), self.max_range, raw)
        raw = np.where(np.isnan(raw), self.max_range, raw)
        raw = np.clip(raw, 0.0, self.max_range)
        
        chunk = len(raw) // self.n_obs
        if chunk > 0:
            self.scan_data = np.array([np.min(raw[i*chunk:(i+1)*chunk]) for i in range(self.n_obs)], dtype=np.float32)
        else:
            self.scan_data = np.full(self.n_obs, self.max_range, dtype=np.float32)
        self.scan_received = True

    def step(self, action):
        twist = Twist()
        
        # 1. Action'ı Uygula
        # Hız (Negatif hıza izin vermiyoruz, sadece ileri)
        # Action[0] normalde -1 ile 1 arasındadır. (action[0] + 1) / 2 yaparak 0-1 arasına çekiyoruz.
        # Böylece robot geri gidip kaçmayı öğrenmez, hep ileri bakar.
        forward_action = (float(action[0]) + 1.0) / 2.0 
        target_speed = forward_action * self.max_speed
        twist.linear.x = target_speed
        
        # Direksiyon
        target_steering = float(action[1]) * self.max_steering
        self.current_steering = (1.0 - self.steering_smooth_factor) * self.current_steering + \
                                (self.steering_smooth_factor * target_steering)
        twist.angular.z = self.current_steering
        
        self.pub_cmd.publish(twist)
        
        # 2. Veri Bekleme
        self.scan_received = False 
        start_wait = time.time()
        while not self.scan_received:
            rclpy.spin_once(self.node, timeout_sec=0.01)
            if time.time() - start_wait > 0.2: break
        
        if self.scan_data is not None:
            observation = self.scan_data / self.max_range
        else:
            observation = np.ones(self.n_obs, dtype=np.float32)
            
        self.step_count += 1
        
        # --- 3. ÖDÜL FONKSİYONU (ANTİ-TEMBELLİK GÜNCELLEMESİ) ---
        real_min_dist = np.min(observation) * self.max_range
        reward = 0.0
        
        # KURAL 1: Sadece Hızlı Gidiyorsa Hayatta Kalma Puanı Ver
        # Eskiden: Sadece nefes aldığı için puan alıyordu.
        # Şimdi: Eğer hızı 0.1 m/s üzerindeyse puan alır.
        if twist.linear.x > 0.1:
            reward += 0.1  # Hızlanmaya teşvik
            reward += twist.linear.x * 0.5 # Hızlandıkça artan ödül
        else:
            reward -= 0.05 # Durduğu yerde durursa ceza (Hadi kımılda!)

        # KURAL 2: Direksiyon Cezasını Artır (Düz gitmeyi sevdir)
        # Sürekli dönmek artık pahalı. Düz gitmek bedava.
        # Direksiyonu ne kadar kırarsa o kadar ceza.
        reward -= abs(self.current_steering) * 0.6  # Ceza katsayısını 0.1'den 0.6'ya çıkardık!
        
        # KURAL 3: Duvara Yaklaşma Cezası (Güvenlik)
        if real_min_dist < 0.6:
            reward -= (0.6 - real_min_dist) * 2.0

        terminated = False
        truncated = False
        
        # Çarpışma Kontrolü
        if real_min_dist < 0.25:
            reward = -20.0 # Cezayı biraz artırdık
            terminated = True
            print(f"💥 Bam! ({self.step_count} adım)")
            self.pub_cmd.publish(Twist()) # Dur
            
        if self.step_count >= self.max_steps:
            truncated = True
        
        return observation.astype(np.float32), reward, terminated, truncated, {}
    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        print("\n⚛️  Kuantum Reset (Direct Gazebo Call)...")
        
        self.current_steering = 0.0
        self.pub_cmd.publish(Twist())

        # ZORLA RESET (Direct Service Call)
        cmd = """gz service -s /world/arena/set_pose \
                 --reqtype gz.msgs.Pose \
                 --reptype gz.msgs.Boolean \
                 --timeout 300 \
                 --req 'name: "my_rc_car", position: {x: 0, y: 0, z: 0.2}, orientation: {w: 1}'"""
        os.system(cmd)
        
        self.scan_data = None
        self.scan_received = False
        start_wait = time.time()
        while time.time() - start_wait < 1.0:
            rclpy.spin_once(self.node, timeout_sec=0.1)
            
        if self.scan_data is not None:
            obs = self.scan_data / self.max_range
        else:
            obs = np.ones(self.n_obs, dtype=np.float32)
            
        self.step_count = 0
        return obs.astype(np.float32), {}

    def close(self):
        self.pub_cmd.publish(Twist())
        self.node.destroy_node()

# ====================================================
# 4. MAIN - EĞİTİMİ BAŞLAT
# ====================================================

if __name__ == '__main__':
    from stable_baselines3 import PPO
    
    log_dir = "./logs_quantum/"
    os.makedirs(log_dir, exist_ok=True)
    
    env = RcCarEnv()
    
    print("---------------------------------------")
    print("⚛️  HİBRİT KUANTUM RL EĞİTİMİ BAŞLIYOR")
    print("🧠  4 Qubit Devrede | PennyLane & PyTorch")
    print("❤️  Engin'i Seviyorum! Geleceği Kodluyoruz!")
    print("---------------------------------------")
    
    # Kuantum Feature Extractor'ı Policy'ye enjekte ediyoruz
    policy_kwargs = dict(
        features_extractor_class=QuantumFeatureExtractor,
        features_extractor_kwargs=dict(features_dim=n_qubits),
    )
    
    model = PPO(
        "MlpPolicy", 
        env, 
        verbose=1, 
        tensorboard_log=log_dir,
        learning_rate=3e-4,     # Kuantum ağları yavaş öğrenir, düşük LR iyidir
        policy_kwargs=policy_kwargs, # <--- KUANTUM KATMANI BURADA
        batch_size=32,          # Kuantum simülasyonu yavaş olduğu için batch'i küçülttüm
        n_steps=512,            # Update sıklığını artırdım
        gamma=0.99
    )
    
    try:
        # Kuantum simülasyonu yavaş olduğu için total_timesteps'i başta düşük tuttum
        # Memnun kalırsan artırabilirsin.
        model.learn(total_timesteps=50000)
        model.save("rc_car_quantum_master")
        print("✅ Kuantum Eğitimi Başarıyla Tamamlandı!")
    except KeyboardInterrupt:
        model.save("rc_car_quantum_interrupted")
        print("\n⚠️ Eğitim durduruldu ve kaydedildi.")
    finally:
        env.close()