File size: 3,442 Bytes
24c2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
VLLM 공유를 위한 Custom Rollout Worker
Step 1-4의 Ray Actor VLLM을 Step 5에서 재사용
"""
import ray
from typing import Optional, Any, Dict, List
from verl.workers.rollout.vllm_rollout import vLLMRollout
from verl.protocol import DataProto
import torch


class SharedVLLMRollout(vLLMRollout):
    """기존 Ray Actor의 VLLM을 재사용하는 Rollout Worker"""
    
    def __init__(self, 
                 actor_handle: Optional[Any] = None,
                 *args, **kwargs):
        """
        Args:
            actor_handle: Step 1-4의 RemoteTestTimePipeline Ray Actor 참조
        """
        self.existing_vllm_actor = actor_handle
        
        if self.existing_vllm_actor is not None:
            print(f"🔗 Using existing VLLM Actor: {self.existing_vllm_actor}")
            # VLLM 엔진 생성을 스킵하고 Actor 참조만 저장
            self.use_external_vllm = True
            # 부모 클래스의 __init__을 호출하되, 모델 로딩은 스킵
            kwargs['skip_model_loading'] = True
        else:
            print("⚠️ No existing VLLM Actor provided, creating new VLLM")
            self.use_external_vllm = False
            
        super().__init__(*args, **kwargs)
    
    def generate(self, 
                 prompts: List[str],
                 sampling_params: Dict[str, Any],
                 **kwargs) -> DataProto:
        """
        텍스트 생성 - 기존 VLLM Actor 사용
        """
        if self.use_external_vllm and self.existing_vllm_actor is not None:
            # Ray Actor의 generate 메서드 호출
            print(f"📡 Calling remote VLLM Actor for {len(prompts)} prompts")
            
            # RemoteTestTimePipeline의 generate 메서드 호출
            # 이 부분은 RemoteTestTimePipeline의 인터페이스에 맞게 조정 필요
            result = ray.get(
                self.existing_vllm_actor.generate_batch_vllm.remote(
                    prompts=prompts,
                    max_tokens=sampling_params.get('max_tokens', 512),
                    temperature=sampling_params.get('temperature', 0.7),
                    top_p=sampling_params.get('top_p', 1.0),
                    n=sampling_params.get('n', 1)
                )
            )
            
            # 결과를 DataProto 형식으로 변환
            return self._convert_to_dataproto(result)
        else:
            # 기존 VLLM 사용 (부모 클래스 메서드)
            return super().generate(prompts, sampling_params, **kwargs)
    
    def _convert_to_dataproto(self, result: Dict[str, Any]) -> DataProto:
        """Ray Actor 결과를 DataProto로 변환"""
        # RemoteTestTimePipeline의 출력 형식에 맞게 조정 필요
        # 예시 구현:
        responses = result.get('responses', [])
        
        data_dict = {
            'responses': responses,
            'input_ids': result.get('input_ids', []),
            'attention_mask': result.get('attention_mask', [])
        }
        
        return DataProto.from_single_dict(data_dict)
    
    def update_weight(self, state_dict: Dict[str, torch.Tensor]):
        """
        모델 가중치 업데이트 - VLLM 공유 시에는 스킵
        """
        if self.use_external_vllm:
            print("🔄 Skipping weight update for shared VLLM (handled by Ray Actor)")
            return
        else:
            super().update_weight(state_dict)