for-the-zero committed on
Commit e84b7dc · 1 Parent(s): d26bfee
Files changed (4) hide show
  1. README.md +2 -2
  2. __pycache__/app.cpython-312.pyc +0 -0
  3. app.py +162 -46
  4. best.pt +2 -2
README.md CHANGED
@@ -4,8 +4,8 @@ emoji: 🌐
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.23.0
8
- python_version: "3.12"
9
  app_file: app.py
10
  pinned: false
11
  license: mit
 
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
+ python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
  license: mit
__pycache__/app.cpython-312.pyc CHANGED
Binary files a/__pycache__/app.cpython-312.pyc and b/__pycache__/app.cpython-312.pyc differ
 
app.py CHANGED
@@ -40,6 +40,8 @@ class DiffusionConfig:
40
  beta_start: float = 0.0001
41
  beta_end: float = 0.02
42
  length_noise_scale: float = 0.3
 
 
43
 
44
 
45
  @dataclass
@@ -372,83 +374,132 @@ class DualOutputProjection(nn.Module):
372
 
373
 
374
  class MultiHeadAttention(nn.Module):
375
- """多头自注意力"""
376
 
377
  def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
378
  super().__init__()
379
  assert d_model % n_heads == 0
 
380
  self.d_model = d_model
381
  self.n_heads = n_heads
382
  self.d_k = d_model // n_heads
383
- self.w_q = nn.Linear(d_model, d_model)
384
- self.w_k = nn.Linear(d_model, d_model)
385
- self.w_v = nn.Linear(d_model, d_model)
386
- self.w_o = nn.Linear(d_model, d_model)
 
 
387
  self.dropout = nn.Dropout(dropout)
388
 
389
  def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
390
  batch_size = q.size(0)
391
- q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
392
- k = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
393
- v = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
394
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
 
 
 
 
 
 
 
 
 
 
395
  if mask is not None:
396
  scores = scores.masked_fill(mask == 0, float('-inf'))
 
397
  attn = F.softmax(scores, dim=-1)
398
  attn = self.dropout(attn)
 
 
399
  out = torch.matmul(attn, v)
400
- out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
 
401
  return self.w_o(out)
402
 
403
 
404
  class FeedForward(nn.Module):
405
- """前馈网络"""
406
 
407
- def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
408
  super().__init__()
409
- self.w1 = nn.Linear(d_model, d_ff)
410
- self.w2 = nn.Linear(d_ff, d_model)
 
 
 
 
 
 
 
 
411
  self.dropout = nn.Dropout(dropout)
412
 
413
  def forward(self, x: torch.Tensor) -> torch.Tensor:
414
- return self.dropout(self.w2(F.gelu(self.w1(x))))
 
 
 
 
 
415
 
416
 
417
  class TransformerBlock(nn.Module):
418
- """Transformer块"""
419
 
420
  def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
421
  super().__init__()
422
- self.attn = MultiHeadAttention(d_model, n_heads, dropout)
423
- self.ff = FeedForward(d_model, d_ff, dropout)
424
  self.norm1 = nn.LayerNorm(d_model)
 
 
425
  self.norm2 = nn.LayerNorm(d_model)
 
 
426
  self.dropout = nn.Dropout(dropout)
427
 
428
  def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
 
429
  x = x + self.dropout(self.attn(self.norm1(x), self.norm1(x), self.norm1(x), mask))
 
430
  x = x + self.dropout(self.ff(self.norm2(x)))
431
  return x
432
 
433
 
434
  class DualNoisePredictor(nn.Module):
435
- """双语言噪声预测器"""
436
 
437
  def __init__(self, d_model: int = 256, n_heads: int = 4, n_layers: int = 4, d_ff: int = 512, max_len: int = 128, dropout: float = 0.1):
438
  super().__init__()
439
  self.d_model = d_model
440
 
441
- # 时间步嵌入(共享)
442
- self.time_embedding = SinusoidalTimeEmbedding(d_model)
443
- self.time_mlp = nn.Sequential(
 
444
  nn.Linear(d_model, d_model * 4),
445
- nn.GELU(),
 
 
446
  nn.Linear(d_model * 4, d_model),
447
  )
448
 
449
- # 语言特定的输入投影
450
- self.zh_input_proj = nn.Linear(d_model, d_model)
451
- self.en_input_proj = nn.Linear(d_model, d_model)
 
 
 
 
 
 
 
 
 
 
452
 
453
  # 共享Transformer层
454
  self.layers = nn.ModuleList([
@@ -456,16 +507,26 @@ class DualNoisePredictor(nn.Module):
456
  for _ in range(n_layers)
457
  ])
458
 
459
- # 语言特定的输出投影
460
- self.zh_output_proj = nn.Linear(d_model, d_model)
461
- self.en_output_proj = nn.Linear(d_model, d_model)
 
 
 
 
 
 
 
 
 
 
462
 
463
  self.output_norm = nn.LayerNorm(d_model)
464
 
465
  def forward(self, x_t: torch.Tensor, t: torch.Tensor, lang: str = "zh", mask: Optional[torch.Tensor] = None) -> torch.Tensor:
466
  # 时间步嵌入
467
- t_emb = self.time_embedding(t)
468
- t_emb = self.time_mlp(t_emb)
469
 
470
  # 语言特定输入投影
471
  if lang == "zh":
@@ -530,10 +591,13 @@ class LanguageSwitcher(nn.Module):
530
 
531
 
532
  # ==================== 扩散过程 ====================
533
- class Diffusion:
 
 
534
  def __init__(self, config: DiffusionConfig):
535
  self.config = config
536
  self.timesteps = config.timesteps
 
537
 
538
  # Beta schedule (linear)
539
  betas = torch.linspace(config.beta_start, config.beta_end, self.timesteps)
@@ -549,15 +613,64 @@ class Diffusion:
549
  def register_buffer(self, name: str, tensor: torch.Tensor):
550
  setattr(self, name, tensor)
551
 
552
- def q_sample(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
  if noise is None:
554
- noise = torch.randn_like(x_0)
 
 
 
 
 
 
 
 
555
  sqrt_alpha = self.sqrt_alphas_cumprod[t]
556
  sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t]
557
- x_t = sqrt_alpha.view(-1, 1, 1) * x_0 + sqrt_one_minus_alpha.view(-1, 1, 1) * noise
 
558
  return x_t, noise
559
 
560
  def p_sample(self, x_t: torch.Tensor, t: torch.Tensor, predicted_noise: torch.Tensor) -> torch.Tensor:
 
561
  beta = self.betas[t]
562
  sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t]
563
  sqrt_recip_alpha = 1.0 / torch.sqrt(self.alphas[t])
@@ -576,7 +689,7 @@ class Diffusion:
576
 
577
 
578
  class DDIMSampler:
579
- def __init__(self, diffusion: Diffusion, ddim_steps: int = 50):
580
  self.diffusion = diffusion
581
  self.ddim_steps = ddim_steps
582
 
@@ -651,7 +764,7 @@ class Translator:
651
  dropout=0.0,
652
  )
653
 
654
- self.diffusion = Diffusion(self.diffusion_config)
655
 
656
  # 加载权重
657
  self._load_checkpoint(os.path.join(model_dir, "best.pt"))
@@ -690,7 +803,7 @@ class Translator:
690
  ddim_steps: int = 50,
691
  show_process: bool = False,
692
  ) -> Tuple[str, List[str]]:
693
- """翻译文本,返回结果和中间过程"""
694
  self.model.eval()
695
  self.embedding.eval()
696
  self.output_proj.eval()
@@ -709,24 +822,27 @@ class Translator:
709
  # 嵌入源语言
710
  source_emb = self.embedding(source_ids, source_lang, source_len)
711
 
712
- # 前向扩散到纯噪声
713
  batch_size = source_emb.size(0)
714
- t_full = torch.full((batch_size,), self.diffusion_config.timesteps - 1, dtype=torch.long)
715
- noise = torch.randn_like(source_emb)
716
- x_t, _ = self.diffusion.q_sample(source_emb, t_full, noise)
 
 
 
717
 
718
  # DDIM反向扩散
719
  timesteps = ddim_sampler.ddim_timesteps
720
  total_steps = len(timesteps)
721
- switch_point = total_steps // 2
722
 
723
  process_steps = []
724
 
725
  for i, t in enumerate(timesteps[:-1]):
726
  t_prev = timesteps[i + 1]
727
 
728
- # 语言切换
729
- if i < switch_point:
 
730
  current_lang = source_lang
731
  else:
732
  current_lang = target_lang
@@ -739,7 +855,7 @@ class Translator:
739
  if show_process and i % max(1, total_steps // 10) == 0:
740
  current_ids = self._embed_to_tokens(x_t, current_lang)
741
  current_text = self._decode(current_ids, current_lang)
742
- process_steps.append(f"Step {t.item()}: {current_text[:50]}")
743
 
744
  # DDIM步骤
745
  x_t = ddim_sampler.ddim_step(x_t, t.item(), t_prev.item(), predicted_noise, eta=0.0)
 
40
  beta_start: float = 0.0001
41
  beta_end: float = 0.02
42
  length_noise_scale: float = 0.3
43
+ interpolation_strength: float = 0.8 # 语言插值强度
44
+ cross_lingual_mode: bool = True # 跨语言扩散模式
45
 
46
 
47
  @dataclass
 
374
 
375
 
376
  class MultiHeadAttention(nn.Module):
377
+ """改进的多头自注意力 - 合并 QKV 投影"""
378
 
379
  def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
380
  super().__init__()
381
  assert d_model % n_heads == 0
382
+
383
  self.d_model = d_model
384
  self.n_heads = n_heads
385
  self.d_k = d_model // n_heads
386
+ self.scale = math.sqrt(self.d_k)
387
+
388
+ # 合并 QKV 投影(更高效)
389
+ self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
390
+ self.w_o = nn.Linear(d_model, d_model, bias=False)
391
+
392
  self.dropout = nn.Dropout(dropout)
393
 
394
  def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
395
  batch_size = q.size(0)
396
+ seq_len = q.size(1)
397
+
398
+ # 合并计算 QKV
399
+ qkv = self.qkv(q)
400
+ q, k, v = qkv.chunk(3, dim=-1)
401
+
402
+ # 分头
403
+ q = q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
404
+ k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
405
+ v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
406
+
407
+ # 注意力计算
408
+ scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
409
+
410
  if mask is not None:
411
  scores = scores.masked_fill(mask == 0, float('-inf'))
412
+
413
  attn = F.softmax(scores, dim=-1)
414
  attn = self.dropout(attn)
415
+
416
+ # 合并头
417
  out = torch.matmul(attn, v)
418
+ out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
419
+
420
  return self.w_o(out)
421
 
422
 
423
  class FeedForward(nn.Module):
424
+ """前馈网络 - 使用 GLU 结构"""
425
 
426
+ def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, use_glu: bool = True):
427
  super().__init__()
428
+ self.use_glu = use_glu
429
+
430
+ if use_glu:
431
+ # GLU 结构 - 更好的表达能力
432
+ self.w1 = nn.Linear(d_model, d_ff * 2)
433
+ self.w2 = nn.Linear(d_ff, d_model)
434
+ else:
435
+ self.w1 = nn.Linear(d_model, d_ff)
436
+ self.w2 = nn.Linear(d_ff, d_model)
437
+
438
  self.dropout = nn.Dropout(dropout)
439
 
440
  def forward(self, x: torch.Tensor) -> torch.Tensor:
441
+ if self.use_glu:
442
+ x, gate = self.w1(x).chunk(2, dim=-1)
443
+ x = F.gelu(x) * F.gelu(gate)
444
+ else:
445
+ x = F.gelu(self.w1(x))
446
+ return self.dropout(self.w2(x))
447
 
448
 
449
  class TransformerBlock(nn.Module):
450
+ """Transformer块 - Pre-LayerNorm 结构"""
451
 
452
  def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
453
  super().__init__()
454
+
455
+ # Pre-LayerNorm 结构
456
  self.norm1 = nn.LayerNorm(d_model)
457
+ self.attn = MultiHeadAttention(d_model, n_heads, dropout)
458
+
459
  self.norm2 = nn.LayerNorm(d_model)
460
+ self.ff = FeedForward(d_model, d_ff, dropout, use_glu=True)
461
+
462
  self.dropout = nn.Dropout(dropout)
463
 
464
  def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
465
+ # 自注意力 + 残差 (Pre-LN)
466
  x = x + self.dropout(self.attn(self.norm1(x), self.norm1(x), self.norm1(x), mask))
467
+ # 前馈 + 残差 (Pre-LN)
468
  x = x + self.dropout(self.ff(self.norm2(x)))
469
  return x
470
 
471
 
472
  class DualNoisePredictor(nn.Module):
473
+ """双语言噪声预测器 - 使用先进架构"""
474
 
475
  def __init__(self, d_model: int = 256, n_heads: int = 4, n_layers: int = 4, d_ff: int = 512, max_len: int = 128, dropout: float = 0.1):
476
  super().__init__()
477
  self.d_model = d_model
478
 
479
+ # 时间步嵌入(共享)- 使用 TimeEmbedding 结构
480
+ self.time_embedding = nn.Module()
481
+ self.time_embedding.sinusoidal = SinusoidalTimeEmbedding(d_model)
482
+ self.time_embedding.mlp = nn.Sequential(
483
  nn.Linear(d_model, d_model * 4),
484
+ nn.SiLU(),
485
+ nn.Linear(d_model * 4, d_model * 4),
486
+ nn.SiLU(),
487
  nn.Linear(d_model * 4, d_model),
488
  )
489
 
490
+ # 语言特定的输入投影(多层)
491
+ self.zh_input_proj = nn.Sequential(
492
+ nn.Linear(d_model, d_model),
493
+ nn.LayerNorm(d_model),
494
+ nn.GELU(),
495
+ nn.Linear(d_model, d_model),
496
+ )
497
+ self.en_input_proj = nn.Sequential(
498
+ nn.Linear(d_model, d_model),
499
+ nn.LayerNorm(d_model),
500
+ nn.GELU(),
501
+ nn.Linear(d_model, d_model),
502
+ )
503
 
504
  # 共享Transformer层
505
  self.layers = nn.ModuleList([
 
507
  for _ in range(n_layers)
508
  ])
509
 
510
+ # 语言特定的输出投影(多层)
511
+ self.zh_output_proj = nn.Sequential(
512
+ nn.Linear(d_model, d_model),
513
+ nn.LayerNorm(d_model),
514
+ nn.GELU(),
515
+ nn.Linear(d_model, d_model),
516
+ )
517
+ self.en_output_proj = nn.Sequential(
518
+ nn.Linear(d_model, d_model),
519
+ nn.LayerNorm(d_model),
520
+ nn.GELU(),
521
+ nn.Linear(d_model, d_model),
522
+ )
523
 
524
  self.output_norm = nn.LayerNorm(d_model)
525
 
526
  def forward(self, x_t: torch.Tensor, t: torch.Tensor, lang: str = "zh", mask: Optional[torch.Tensor] = None) -> torch.Tensor:
527
  # 时间步嵌入
528
+ t_emb = self.time_embedding.sinusoidal(t)
529
+ t_emb = self.time_embedding.mlp(t_emb)
530
 
531
  # 语言特定输入投影
532
  if lang == "zh":
 
591
 
592
 
593
  # ==================== 扩散过程 ====================
594
+ class CrossLingualDiffusion:
595
+ """跨语言扩散模型:支持源语言和目标语言之间的插值"""
596
+
597
  def __init__(self, config: DiffusionConfig):
598
  self.config = config
599
  self.timesteps = config.timesteps
600
+ self.interpolation_strength = config.interpolation_strength
601
 
602
  # Beta schedule (linear)
603
  betas = torch.linspace(config.beta_start, config.beta_end, self.timesteps)
 
613
  def register_buffer(self, name: str, tensor: torch.Tensor):
614
  setattr(self, name, tensor)
615
 
616
+ def get_interpolation_factor(self, t: torch.Tensor) -> torch.Tensor:
617
+ """计算插值因子(smoothstep平滑过渡)"""
618
+ normalized_t = t.float() / self.timesteps
619
+ # smoothstep: 3t^2 - 2t^3
620
+ factor = normalized_t * normalized_t * (3 - 2 * normalized_t)
621
+ return factor * self.interpolation_strength
622
+
623
+ def _align_sequences(self, x_source: torch.Tensor, x_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
624
+ """对齐两个序列到相同长度"""
625
+ source_len = x_source.size(1)
626
+ target_len = x_target.size(1)
627
+ target_seq_len = max(source_len, target_len)
628
+
629
+ if source_len < target_seq_len:
630
+ # 填充源序列
631
+ pad_len = target_seq_len - source_len
632
+ x_source_aligned = F.pad(x_source, (0, 0, 0, pad_len))
633
+ else:
634
+ x_source_aligned = x_source
635
+
636
+ if target_len < target_seq_len:
637
+ # 填充目标序列
638
+ pad_len = target_seq_len - target_len
639
+ x_target_aligned = F.pad(x_target, (0, 0, 0, pad_len))
640
+ else:
641
+ x_target_aligned = x_target
642
+
643
+ return x_source_aligned, x_target_aligned, target_seq_len
644
+
645
+ def q_sample(
646
+ self,
647
+ x_source: torch.Tensor,
648
+ x_target: torch.Tensor,
649
+ t: torch.Tensor,
650
+ noise: Optional[torch.Tensor] = None
651
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
652
+ """跨语言前向扩散:源语言和目标语言之间的插值 + 加噪"""
653
+ # 对齐序列
654
+ x_source_aligned, x_target_aligned, seq_len = self._align_sequences(x_source, x_target)
655
+
656
  if noise is None:
657
+ noise = torch.randn_like(x_source_aligned)
658
+
659
+ # 计算插值因子
660
+ interp_factor = self.get_interpolation_factor(t).view(-1, 1, 1)
661
+
662
+ # 插值:从源语言逐渐过渡到目标语言
663
+ x_interp = (1 - interp_factor) * x_source_aligned + interp_factor * x_target_aligned
664
+
665
+ # 加噪
666
  sqrt_alpha = self.sqrt_alphas_cumprod[t]
667
  sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t]
668
+ x_t = sqrt_alpha.view(-1, 1, 1) * x_interp + sqrt_one_minus_alpha.view(-1, 1, 1) * noise
669
+
670
  return x_t, noise
671
 
672
  def p_sample(self, x_t: torch.Tensor, t: torch.Tensor, predicted_noise: torch.Tensor) -> torch.Tensor:
673
+ """反向扩散单步"""
674
  beta = self.betas[t]
675
  sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t]
676
  sqrt_recip_alpha = 1.0 / torch.sqrt(self.alphas[t])
 
689
 
690
 
691
  class DDIMSampler:
692
+ def __init__(self, diffusion: CrossLingualDiffusion, ddim_steps: int = 50):
693
  self.diffusion = diffusion
694
  self.ddim_steps = ddim_steps
695
 
 
764
  dropout=0.0,
765
  )
766
 
767
+ self.diffusion = CrossLingualDiffusion(self.diffusion_config)
768
 
769
  # 加载权重
770
  self._load_checkpoint(os.path.join(model_dir, "best.pt"))
 
803
  ddim_steps: int = 50,
804
  show_process: bool = False,
805
  ) -> Tuple[str, List[str]]:
806
+ """翻译文本,返回结果和中间过程(跨语言扩散)"""
807
  self.model.eval()
808
  self.embedding.eval()
809
  self.output_proj.eval()
 
822
  # 嵌入源语言
823
  source_emb = self.embedding(source_ids, source_lang, source_len)
824
 
825
+ # 初始状态:使用跨语言扩散的前向过程
826
  batch_size = source_emb.size(0)
827
+ t_start = torch.full((batch_size,), self.diffusion_config.timesteps - 1, dtype=torch.long)
828
+ noise_start = torch.randn_like(source_emb)
829
+
830
+ # 模拟目标语言嵌入(随机噪声 + 源语言信息)
831
+ target_emb_fake = torch.randn_like(source_emb) * 0.3 + source_emb * 0.7
832
+ x_t, _ = self.diffusion.q_sample(source_emb, target_emb_fake, t_start, noise_start)
833
 
834
  # DDIM反向扩散
835
  timesteps = ddim_sampler.ddim_timesteps
836
  total_steps = len(timesteps)
 
837
 
838
  process_steps = []
839
 
840
  for i, t in enumerate(timesteps[:-1]):
841
  t_prev = timesteps[i + 1]
842
 
843
+ # 计算进度,决定当前语言
844
+ progress = i / total_steps
845
+ if progress < 0.3:
846
  current_lang = source_lang
847
  else:
848
  current_lang = target_lang
 
855
  if show_process and i % max(1, total_steps // 10) == 0:
856
  current_ids = self._embed_to_tokens(x_t, current_lang)
857
  current_text = self._decode(current_ids, current_lang)
858
+ process_steps.append(f"Step {t.item()} [{current_lang}]: {current_text[:50]}")
859
 
860
  # DDIM步骤
861
  x_t = ddim_sampler.ddim_step(x_t, t.item(), t_prev.item(), predicted_noise, eta=0.0)
best.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:283c995651aa11ebde09858ad5002cd932c1dd6dd5ede16be733c16cbb5c4c55
3
- size 47986610
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b85cd0723b960f4f579c9b39e9ce87c7981feadf8184bab31fff128798f9bae
3
+ size 70043898