import requests
from typing import List
from tenacity import retry, stop_after_attempt, wait_random_exponential
from models.Base import BaseModel
from openai import OpenAI


class VLLMModel(BaseModel):
    def __init__(self,
                 model_id="local-vllm",
                 base_url="http://localhost:8041/v1",
                 api_key=None):
        """
        model_id: 模型名称或标识 (通常不重要,vLLM 会忽略)
        base_url: vLLM API 地址,例如 http://localhost:8000/v1
        api_key: 可选(vLLM 默认不验证)
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key

        # Discover the model name the vLLM server actually serves; the
        # OpenAI-compatible client requires a non-empty api_key, so fall
        # back to the conventional "EMPTY" placeholder when none is given.
        client = OpenAI(
            api_key=api_key or "EMPTY",
            base_url=self.base_url,
        )
        models = client.models.list()
        self.model_id = models.data[0].id if models.data else model_id

        # Request headers for the raw HTTP calls in generate()
        self.headers = {
            "Content-Type": "application/json"
        }
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

        # Chat-completions endpoint used by generate()
        self.endpoint = f"{self.base_url}/chat/completions"

    @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(3))
    def generate(self,
                 messages: List,
                 temperature=0,
                 presence_penalty=0,
                 frequency_penalty=0,
                 max_tokens=5000) -> str:
        """
        调用本地 vLLM 推理接口
        """
        payload = {
            "model": self.model_id,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "presence_penalty": presence_penalty,
            "frequency_penalty": frequency_penalty,
        }

        # Generous timeout (in seconds) to accommodate long generations.
        response = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=3000)

        if response.status_code != 200:
            raise RuntimeError(f"vLLM request failed: {response.status_code}, {response.text}")

        data = response.json()
        if "choices" not in data or len(data["choices"]) == 0:
            raise ValueError("No response choices returned from vLLM API.")

        return data["choices"][0]["message"]["content"]