File size: 3,788 Bytes
db06013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from vllm import LLM, SamplingParams
from typing import List, Dict, Any, Optional
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger(__name__)

class VLLMServer:
    """Thin wrapper around a vLLM ``LLM`` engine with sync, batched, and
    async entry points.

    The model is loaded lazily: construction is cheap, and the engine is
    created on the first call to ``generate`` (or explicitly via
    ``initialize``). Async callers are served by pushing the blocking
    ``generate`` call onto a small thread pool.
    """

    def __init__(self, model_name: str = "openai/gpt-oss-20b", 
                 tensor_parallel_size: int = 1, gpu_memory_utilization: float = 0.9):
        # Engine configuration; only stored here — nothing is loaded yet.
        self.model_name = model_name
        self.tensor_parallel_size = tensor_parallel_size
        self.gpu_memory_utilization = gpu_memory_utilization
        # Lazily-created vLLM engine; None until initialize() runs.
        self.llm = None
        # Pool used by generate_async() to keep the event loop unblocked.
        self.executor = ThreadPoolExecutor(max_workers=4)

    def initialize(self):
        """Create the vLLM engine from the stored configuration.

        Raises:
            Exception: re-raises whatever vLLM raises on load failure,
                after logging it.
        """
        try:
            self.llm = LLM(
                model=self.model_name,
                tensor_parallel_size=self.tensor_parallel_size,
                gpu_memory_utilization=self.gpu_memory_utilization,
                trust_remote_code=True
            )
            logger.info(f"Initialized vLLM with model: {self.model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize vLLM: {e}")
            raise

    def generate(self, prompts: List[str], 
                max_tokens: int = 512,
                temperature: float = 0.7,
                top_p: float = 0.9,
                stop: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Generate one completion per prompt (blocking).

        Args:
            prompts: Prompt strings to complete.
            max_tokens: Per-completion token budget.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling cutoff.
            stop: Optional stop strings passed through to vLLM.

        Returns:
            One dict per prompt with ``text``, ``prompt``,
            ``finish_reason``, ``token_ids`` and ``logprobs`` keys.

        Raises:
            Exception: re-raises vLLM generation errors after logging.
        """
        # Lazy initialization on first use.
        if self.llm is None:
            self.initialize()

        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop
        )

        try:
            outputs = self.llm.generate(prompts, sampling_params)

            results = []
            for output in outputs:
                # Only the first candidate per prompt is surfaced.
                results.append({
                    'text': output.outputs[0].text,
                    'prompt': output.prompt,
                    'finish_reason': output.outputs[0].finish_reason,
                    'token_ids': output.outputs[0].token_ids,
                    # getattr: older vLLM outputs may lack logprobs.
                    'logprobs': getattr(output.outputs[0], 'logprobs', None)
                })

            return results
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_single(self, prompt: str, **kwargs) -> str:
        """Generate a completion for one prompt and return only its text."""
        results = self.generate([prompt], **kwargs)
        return results[0]['text'] if results else ""

    def generate_batch(self, prompts: List[str], batch_size: int = 8, **kwargs) -> List[str]:
        """Generate texts for many prompts, submitting them in batches.

        Args:
            prompts: All prompts to complete.
            batch_size: How many prompts to hand to ``generate`` per call.

        Returns:
            The generated text for each prompt, in input order.
        """
        all_results = []

        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i + batch_size]
            batch_results = self.generate(batch_prompts, **kwargs)
            all_results.extend([r['text'] for r in batch_results])

        return all_results

    async def generate_async(self, prompts: List[str], **kwargs) -> List[Dict[str, Any]]:
        """Run ``generate`` on the thread pool without blocking the event loop.

        Accepts the same keyword arguments as ``generate``.
        """
        # get_running_loop() is the correct call inside a coroutine
        # (get_event_loop() here is deprecated).
        loop = asyncio.get_running_loop()
        # BUG FIX: run_in_executor() does not forward keyword arguments —
        # the original `run_in_executor(self.executor, self.generate,
        # prompts, **kwargs)` raised TypeError whenever kwargs were given.
        # Bind them into a closure instead.
        return await loop.run_in_executor(
            self.executor, lambda: self.generate(prompts, **kwargs)
        )

    def get_model_info(self) -> Dict[str, Any]:
        """Return the engine configuration, or ``{}`` if not yet initialized."""
        if self.llm is None:
            return {}

        return {
            'model_name': self.model_name,
            'tensor_parallel_size': self.tensor_parallel_size,
            'gpu_memory_utilization': self.gpu_memory_utilization,
            'is_initialized': self.llm is not None
        }

    def cleanup(self):
        """Shut down the thread pool; safe to call multiple times."""
        if self.executor:
            self.executor.shutdown(wait=True)