hank9999 commited on
Commit
4327a43
·
1 Parent(s): c2bdb0a

feat: 添加 ping 保活机制,定时发送 SSE ping 事件

Browse files
Files changed (1) hide show
  1. src/anthropic/handlers.rs +66 -45
src/anthropic/handlers.rs CHANGED
@@ -12,6 +12,8 @@ use axum::{
12
  use bytes::Bytes;
13
  use futures::{stream, Stream, StreamExt};
14
  use serde_json::json;
 
 
15
  use uuid::Uuid;
16
  use crate::anthropic::token;
17
  use crate::kiro::model::events::Event;
@@ -206,6 +208,14 @@ async fn handle_stream_request(
206
  .unwrap()
207
  }
208
 
 
 
 
 
 
 
 
 
209
  /// 创建 SSE 事件流
210
  fn create_sse_stream(
211
  response: reqwest::Response,
@@ -219,63 +229,74 @@ fn create_sse_stream(
219
  .map(|e| Ok(Bytes::from(e.to_sse_string()))),
220
  );
221
 
222
- // 然后处理 Kiro 响应流
223
  let body_stream = response.bytes_stream();
224
 
225
  let processing_stream = stream::unfold(
226
- (body_stream, ctx, EventStreamDecoder::new(), false),
227
- |(mut body_stream, mut ctx, mut decoder, finished)| async move {
228
  if finished {
229
  return None;
230
  }
231
 
232
- // 尝试获取下一个 chunk
233
- match body_stream.next().await {
234
- Some(Ok(chunk)) => {
235
- // 解码事件
236
- decoder.feed(&chunk);
237
-
238
- let mut events = Vec::new();
239
- for result in decoder.decode_iter() {
240
- match result {
241
- Ok(frame) => {
242
- if let Ok(event) = Event::from_frame(frame) {
243
- let sse_events = ctx.process_kiro_event(&event);
244
- events.extend(sse_events);
 
 
 
 
 
 
 
 
245
  }
246
  }
247
- Err(e) => {
248
- tracing::warn!("解码事件失败: {}", e);
249
- }
250
- }
251
- }
252
 
253
- // 转换为 SSE 字节流
254
- let bytes: Vec<Result<Bytes, Infallible>> = events
255
- .into_iter()
256
- .map(|e| Ok(Bytes::from(e.to_sse_string())))
257
- .collect();
258
 
259
- Some((stream::iter(bytes), (body_stream, ctx, decoder, false)))
260
- }
261
- Some(Err(e)) => {
262
- tracing::error!("读取响应流失败: {}", e);
263
- // 发送最终事件并结束
264
- let final_events = ctx.generate_final_events();
265
- let bytes: Vec<Result<Bytes, Infallible>> = final_events
266
- .into_iter()
267
- .map(|e| Ok(Bytes::from(e.to_sse_string())))
268
- .collect();
269
- Some((stream::iter(bytes), (body_stream, ctx, decoder, true)))
 
 
 
 
 
 
 
 
 
 
 
270
  }
271
- None => {
272
- // 流结束,发送最终事件
273
- let final_events = ctx.generate_final_events();
274
- let bytes: Vec<Result<Bytes, Infallible>> = final_events
275
- .into_iter()
276
- .map(|e| Ok(Bytes::from(e.to_sse_string())))
277
- .collect();
278
- Some((stream::iter(bytes), (body_stream, ctx, decoder, true)))
279
  }
280
  }
281
  },
 
12
  use bytes::Bytes;
13
  use futures::{stream, Stream, StreamExt};
14
  use serde_json::json;
15
+ use std::time::Duration;
16
+ use tokio::time::interval;
17
  use uuid::Uuid;
18
  use crate::anthropic::token;
19
  use crate::kiro::model::events::Event;
 
208
  .unwrap()
209
  }
210
 
211
+ /// Ping 事件间隔(25秒)
212
+ const PING_INTERVAL_SECS: u64 = 25;
213
+
214
+ /// 创建 ping 事件的 SSE 字符串
215
+ fn create_ping_sse() -> Bytes {
216
+ Bytes::from("event: ping\ndata: {\"type\": \"ping\"}\n\n")
217
+ }
218
+
219
  /// 创建 SSE 事件流
220
  fn create_sse_stream(
221
  response: reqwest::Response,
 
229
  .map(|e| Ok(Bytes::from(e.to_sse_string()))),
230
  );
231
 
232
+ // 然后处理 Kiro 响应流,同时每25秒发送 ping 保活
233
  let body_stream = response.bytes_stream();
234
 
235
  let processing_stream = stream::unfold(
236
+ (body_stream, ctx, EventStreamDecoder::new(), false, interval(Duration::from_secs(PING_INTERVAL_SECS))),
237
+ |(mut body_stream, mut ctx, mut decoder, finished, mut ping_interval)| async move {
238
  if finished {
239
  return None;
240
  }
241
 
242
+ // 使用 select! 同时等待数据和 ping 定时器
243
+ tokio::select! {
244
+ // 处理数据流
245
+ chunk_result = body_stream.next() => {
246
+ match chunk_result {
247
+ Some(Ok(chunk)) => {
248
+ // 解码事件
249
+ decoder.feed(&chunk);
250
+
251
+ let mut events = Vec::new();
252
+ for result in decoder.decode_iter() {
253
+ match result {
254
+ Ok(frame) => {
255
+ if let Ok(event) = Event::from_frame(frame) {
256
+ let sse_events = ctx.process_kiro_event(&event);
257
+ events.extend(sse_events);
258
+ }
259
+ }
260
+ Err(e) => {
261
+ tracing::warn!("解码事件失败: {}", e);
262
+ }
263
  }
264
  }
 
 
 
 
 
265
 
266
+ // 转换为 SSE 字节流
267
+ let bytes: Vec<Result<Bytes, Infallible>> = events
268
+ .into_iter()
269
+ .map(|e| Ok(Bytes::from(e.to_sse_string())))
270
+ .collect();
271
 
272
+ Some((stream::iter(bytes), (body_stream, ctx, decoder, false, ping_interval)))
273
+ }
274
+ Some(Err(e)) => {
275
+ tracing::error!("读取响应流失败: {}", e);
276
+ // 发送最终事件并结束
277
+ let final_events = ctx.generate_final_events();
278
+ let bytes: Vec<Result<Bytes, Infallible>> = final_events
279
+ .into_iter()
280
+ .map(|e| Ok(Bytes::from(e.to_sse_string())))
281
+ .collect();
282
+ Some((stream::iter(bytes), (body_stream, ctx, decoder, true, ping_interval)))
283
+ }
284
+ None => {
285
+ // 流结束,发送最终事件
286
+ let final_events = ctx.generate_final_events();
287
+ let bytes: Vec<Result<Bytes, Infallible>> = final_events
288
+ .into_iter()
289
+ .map(|e| Ok(Bytes::from(e.to_sse_string())))
290
+ .collect();
291
+ Some((stream::iter(bytes), (body_stream, ctx, decoder, true, ping_interval)))
292
+ }
293
+ }
294
  }
295
+ // 发送 ping 保活
296
+ _ = ping_interval.tick() => {
297
+ tracing::trace!("发送 ping 保活事件");
298
+ let bytes: Vec<Result<Bytes, Infallible>> = vec![Ok(create_ping_sse())];
299
+ Some((stream::iter(bytes), (body_stream, ctx, decoder, false, ping_interval)))
 
 
 
300
  }
301
  }
302
  },