import requests
from typing import List

from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_random_exponential

from models.Base import BaseModel


class VLLMModel(BaseModel):
    def __init__(self, model_id="local-vllm", base_url="http://localhost:8041/v1", api_key=None):
        """
        model_id: model name or identifier (usually unimportant; vLLM ignores it)
        base_url: vLLM API address, e.g. http://localhost:8000/v1
        api_key: optional (vLLM does not verify it by default)
        """
        # Query the server for the model it is actually serving; the passed-in
        # model_id is superseded by the first (and typically only) entry.
        client = OpenAI(
            api_key=api_key or "EMPTY",
            base_url=base_url,
        )
        models = client.models.list()
        self.model_id = models.data[0].id

        self.base_url = base_url.rstrip("/")
        self.api_key = api_key

        # Request headers; attach a bearer token only if an API key was supplied.
        self.headers = {
            "Content-Type": "application/json"
        }
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

        # Chat-completions endpoint
        self.endpoint = f"{self.base_url}/chat/completions"

    @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(3))
    def generate(self, messages: List, temperature=0, presence_penalty=0,
                 frequency_penalty=0, max_tokens=5000) -> str:
        """
        Call the local vLLM inference endpoint and return the generated text.
        """
        payload = {
            "model": self.model_id,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "presence_penalty": presence_penalty,
            "frequency_penalty": frequency_penalty,
        }
        # Long timeout (in seconds) to accommodate slow generations.
        response = requests.post(self.endpoint, headers=self.headers,
                                 json=payload, timeout=3000)
        if response.status_code != 200:
            raise RuntimeError(f"vLLM request failed: {response.status_code}, {response.text}")

        data = response.json()
        if "choices" not in data or len(data["choices"]) == 0:
            raise ValueError("No response choices returned from vLLM API.")
        return data["choices"][0]["message"]["content"]
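
# Usage sketch (assumes a vLLM OpenAI-compatible server is already running at
# the given base_url, e.g. started with `vllm serve <model>`; the host, port,
# and prompt below are illustrative, not part of the original module):
if __name__ == "__main__":
    model = VLLMModel(base_url="http://localhost:8041/v1")
    reply = model.generate(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ],
        temperature=0,
        max_tokens=64,
    )
    print(reply)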