Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """LLM client wrapper for the Text2SPARQL repair pipeline. | |
| All LLM access goes through this file. Supports three backends: | |
| - "vllm": vLLM inline mode for local models (e.g., Qwen2.5 27B AWQ) | |
| - "openai": OpenAI-compatible API for cloud OpenAI models | |
| - "anthropic": Anthropic Messages API for Claude models | |
| Centralizes retries, parsing, and model selection. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import time | |
| from typing import Type, TypeVar | |
| from pydantic import BaseModel | |
# Attempt budget shared by the client's retry loops: each loop makes at most
# _MAX_RETRIES attempts in total (so _MAX_RETRIES - 1 actual retries) and
# sleeps _RETRY_DELAY_SEC * attempt seconds after a failed attempt.
_MAX_RETRIES: int = 3
_RETRY_DELAY_SEC: float = 1.0

logger = logging.getLogger(__name__)

# Type variable for the Pydantic model class passed to generate_json();
# the method returns an instance of that same class.
T = TypeVar("T", bound=BaseModel)
class LLMClient:
    """Unified wrapper for LLM inference via vLLM, OpenAI API, or Anthropic API.

    Backend selection (resolved in ``__init__``):
      - ``backend="vllm"``: in-process vLLM engine for local models.
      - ``backend="openai"``: OpenAI Chat Completions; requires
        ``OPENAI_API_KEY`` (``OPENAI_BASE_URL`` optionally overrides the
        endpoint).
      - ``backend="anthropic"``: Anthropic Messages API; requires
        ``ANTHROPIC_API_KEY`` (``ANTHROPIC_BASE_URL`` optionally overrides the
        endpoint).
    When ``backend`` is omitted, ``$LLM_BACKEND`` is consulted; only
    "openai" and "anthropic" are honoured, anything else selects vLLM.

    For vLLM, the engine and tokenizer are cached on the class and reused by
    every client created for the same model name.
    """

    # Class-level vLLM cache, shared by all clients for the same model name.
    _vllm_instance = None
    _vllm_tokenizer = None
    _vllm_model_name = None

    # Tokens kept free (beyond prompt + completion) inside max_model_len when
    # fitting prompts into the vLLM context window.
    _VLLM_CONTEXT_MARGIN = 50

    # Name prefixes of OpenAI o-series reasoning models. Per OpenAI's
    # reasoning-model docs these, like GPT-5, need `max_completion_tokens`
    # instead of `max_tokens` and only accept default sampling parameters.
    _OPENAI_REASONING_PREFIXES = ("o1", "o3", "o4")

    def __init__(
        self,
        model_name: str,
        temperature: float,
        backend: str | None = None,
        max_tokens: int = 4096,
        gpu_memory_utilization: float = 0.85,
        max_model_len: int = 8192,
        enforce_eager: bool = True,
    ) -> None:
        """Configure the client and initialise the selected backend.

        Args:
            model_name: Model identifier (HF repo/path for vLLM, API model
                name for OpenAI/Anthropic).
            temperature: Sampling temperature used for every request
                (not sent to GPT-5 / o-series OpenAI models).
            backend: "vllm", "openai" or "anthropic" (case-insensitive).
                ``None`` means: use ``$LLM_BACKEND`` if it is "openai" or
                "anthropic", otherwise vLLM.
            max_tokens: Default maximum number of output tokens.
            gpu_memory_utilization: vLLM only; fraction of GPU memory the
                engine may use.
            max_model_len: vLLM only; context length (prompt + output).
            enforce_eager: vLLM only; forwarded to the vLLM engine.

        Raises:
            RuntimeError: If the backend name is unsupported or the backend
                cannot be initialised.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.gpu_memory_utilization = gpu_memory_utilization
        self.max_model_len = max_model_len
        self.enforce_eager = enforce_eager

        if backend is not None:
            self.backend = backend.strip().lower()
        else:
            env_backend = os.environ.get("LLM_BACKEND", "").strip().lower()
            # Unknown or empty $LLM_BACKEND values fall back to local vLLM.
            self.backend = (
                env_backend if env_backend in {"openai", "anthropic"} else "vllm"
            )

        if self.backend == "vllm":
            self._init_vllm()
        elif self.backend == "openai":
            self._init_openai()
        elif self.backend == "anthropic":
            self._init_anthropic()
        else:
            raise RuntimeError(f"Unsupported llm backend: {self.backend}")

    # ── vLLM backend ─────────────────────────────────────────────

    def _init_vllm(self) -> None:
        """Load (or reuse) the vLLM engine and tokenizer for ``self.model_name``.

        The engine is cached on the class so clients for the same model share
        one copy. The client also keeps its own references to the engine and
        tokenizer it was created with. Creating a client for a different model
        later therefore cannot redirect this client's requests to that model.

        Note: when an engine for the same model name is already cached, this
        client's ``gpu_memory_utilization``/``max_model_len``/``enforce_eager``
        are not applied to it; the cached engine keeps its original settings.

        Raises:
            RuntimeError: If ``vllm``/``transformers`` cannot be imported.
        """
        if (
            LLMClient._vllm_instance is not None
            and LLMClient._vllm_model_name == self.model_name
        ):
            logger.info("Reusing existing vLLM instance for %s", self.model_name)
        else:
            try:
                from vllm import LLM as VllmLLM
                from transformers import AutoTokenizer

                logger.info("Initializing vLLM model: %s ...", self.model_name)
                engine = VllmLLM(
                    model=self.model_name,
                    trust_remote_code=True,
                    gpu_memory_utilization=self.gpu_memory_utilization,
                    max_model_len=self.max_model_len,
                    tensor_parallel_size=1,
                    enable_prefix_caching=True,
                    enforce_eager=self.enforce_eager,
                )
                tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name, trust_remote_code=True
                )
            except ImportError as exc:
                logger.error(
                    "vLLM or transformers not installed. "
                    "Install with: pip install vllm transformers"
                )
                raise RuntimeError(
                    f"vLLM backend requires vllm package: {exc}"
                ) from exc
            # Publish the cache only once both pieces loaded, so a failed
            # tokenizer load cannot leave a new engine under the old name.
            LLMClient._vllm_instance = engine
            LLMClient._vllm_tokenizer = tokenizer
            LLMClient._vllm_model_name = self.model_name
            logger.info("vLLM model loaded successfully.")

        self._vllm_llm = LLMClient._vllm_instance
        self._vllm_tok = LLMClient._vllm_tokenizer

    def _vllm_engine(self):
        """Return the ``(engine, tokenizer)`` pair this client should use.

        Prefers the references bound by ``_init_vllm``. It falls back to the
        class-level cache for instances created without ``__init__``.

        Raises:
            RuntimeError: If no engine/tokenizer is available.
        """
        engine = getattr(self, "_vllm_llm", None)
        tokenizer = getattr(self, "_vllm_tok", None)
        if engine is None:
            engine = LLMClient._vllm_instance
        if tokenizer is None:
            tokenizer = LLMClient._vllm_tokenizer
        if engine is None or tokenizer is None:
            raise RuntimeError("vLLM engine not initialized")
        return engine, tokenizer

    def _vllm_format_prompt(
        self, tokenizer, prompt: str, max_output_tokens: int
    ) -> str:
        """Apply the chat template to *prompt*, truncating its content if needed.

        The templated prompt must fit in ``max_model_len`` minus
        ``max_output_tokens`` minus ``_VLLM_CONTEXT_MARGIN`` tokens. When it
        does not fit, the user content is cut from the end and the template is
        re-applied. This keeps the template's closing tokens and the generation
        prompt intact; cutting the templated string itself would drop the
        assistant header that ``add_generation_prompt=True`` appends.

        Raises:
            ValueError: If ``max_output_tokens`` leaves no room for a prompt.
        """
        budget = self.max_model_len - max_output_tokens - self._VLLM_CONTEXT_MARGIN
        if budget <= 0:
            raise ValueError(
                f"max_tokens={max_output_tokens} leaves no room for the prompt "
                f"within max_model_len={self.max_model_len}"
            )

        def render(content: str) -> str:
            return tokenizer.apply_chat_template(
                [{"role": "user", "content": content}],
                tokenize=False,
                add_generation_prompt=True,
            )

        formatted = render(prompt)
        overflow = len(tokenizer.encode(formatted)) - budget
        if overflow <= 0:
            return formatted

        prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
        keep = max(len(prompt_ids) - overflow, 0)
        logger.warning(
            "Prompt exceeds the vLLM context budget by %d tokens; "
            "truncating user content to %d tokens.",
            overflow, keep,
        )
        truncated = tokenizer.decode(prompt_ids[:keep], skip_special_tokens=False)
        return render(truncated)

    def _vllm_sampling_params(self, max_output_tokens: int):
        """Build vLLM ``SamplingParams`` from the client's temperature."""
        from vllm import SamplingParams

        return SamplingParams(
            temperature=self.temperature,
            # Nucleus sampling is only narrowed when temperature is non-zero.
            top_p=1.0 if self.temperature == 0.0 else 0.95,
            max_tokens=max_output_tokens,
        )

    @staticmethod
    def _clean_vllm_output(text: str) -> str:
        """Strip whitespace and ``<think>`` reasoning from a raw completion.

        Complete ``<think>...</think>`` blocks are removed. If a closing
        ``</think>`` remains without its opening tag, everything up to and
        including the last one is dropped too. That happens presumably when
        the chat template already emitted the opening tag.
        """
        text = re.sub(r"<think>.*?</think>", "", text.strip(), flags=re.DOTALL)
        if "</think>" in text:
            text = text.rsplit("</think>", 1)[1]
        return text.strip()

    def _vllm_generate(self, prompt: str, max_tokens: int | None = None) -> str:
        """Generate a single completion with the vLLM engine.

        Args:
            prompt: User message content.
            max_tokens: Output-token limit; falsy values use ``self.max_tokens``.

        Returns:
            The cleaned completion text, or "" if vLLM returned no output.
        """
        effective_max_tokens = max_tokens or self.max_tokens
        engine, tokenizer = self._vllm_engine()
        formatted = self._vllm_format_prompt(tokenizer, prompt, effective_max_tokens)
        outputs = engine.generate(
            [formatted], self._vllm_sampling_params(effective_max_tokens)
        )
        if outputs and outputs[0].outputs:
            return self._clean_vllm_output(outputs[0].outputs[0].text)
        return ""

    # ── OpenAI backend ───────────────────────────────────────────

    def _init_openai(self) -> None:
        """Create the OpenAI client from environment configuration.

        If the ``openai`` package is missing, a warning is logged and
        ``self._openai_client`` is set to ``None``; generation calls then
        fail with "OpenAI client not initialized".

        Raises:
            RuntimeError: If ``OPENAI_API_KEY`` is unset or client
                construction fails.
        """
        try:
            from openai import OpenAI
        except ImportError:
            logger.warning(
                "openai package not installed. Install with: pip install openai"
            )
            self._openai_client = None
            return

        api_key = os.environ.get("OPENAI_API_KEY", "")
        if not api_key:
            raise RuntimeError(
                "OPENAI_API_KEY is not set. Export it before using "
                "llm_backend=openai."
            )
        try:
            self._openai_client = OpenAI(
                api_key=api_key,
                base_url=os.environ.get("OPENAI_BASE_URL"),
            )
        except Exception as exc:
            raise RuntimeError(f"Failed to initialize OpenAI client: {exc}") from exc

    def _openai_is_gpt5_family(self) -> bool:
        """Return whether the model name starts with "gpt-5" (case-insensitive)."""
        normalized = self.model_name.strip().lower()
        return normalized.startswith("gpt-5")

    def _openai_is_reasoning_model(self) -> bool:
        """Return whether the model is GPT-5 family or an o-series reasoning model."""
        normalized = self.model_name.strip().lower()
        return self._openai_is_gpt5_family() or normalized.startswith(
            self._OPENAI_REASONING_PREFIXES
        )

    def _openai_supports_sampling_params(self) -> bool:
        """Return whether it is safe to send ``temperature``-style params.

        GPT-5 family and o-series reasoning models on Chat Completions should
        not receive sampling parameters such as ``temperature``.
        """
        return not self._openai_is_reasoning_model()

    def _openai_uses_max_completion_tokens(self) -> bool:
        """Return whether the model expects ``max_completion_tokens``.

        GPT-5 family and o-series reasoning models reject ``max_tokens`` on
        Chat Completions and require ``max_completion_tokens`` instead.
        """
        return self._openai_is_reasoning_model()

    def _openai_generate(self, prompt: str, max_tokens: int | None = None) -> str:
        """Generate text with the OpenAI Chat Completions API.

        Args:
            prompt: User message content.
            max_tokens: Output-token limit; falsy values use ``self.max_tokens``.

        Returns:
            The first choice's message content ("" if it is empty).

        Raises:
            RuntimeError: If the client was not initialised.
        """
        if self._openai_client is None:
            raise RuntimeError("OpenAI client not initialized")

        effective_max_tokens = max_tokens or self.max_tokens
        request_kwargs: dict[str, object] = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
        }
        # Non-reasoning (and OpenAI-compatible) endpoints keep `max_tokens`.
        token_param = (
            "max_completion_tokens"
            if self._openai_uses_max_completion_tokens()
            else "max_tokens"
        )
        request_kwargs[token_param] = effective_max_tokens
        if self._openai_supports_sampling_params():
            request_kwargs["temperature"] = self.temperature

        response = self._openai_client.chat.completions.create(**request_kwargs)
        logger.info(
            "OpenAI response model served: requested=%s served=%s",
            self.model_name,
            getattr(response, "model", "<unknown>"),
        )
        return response.choices[0].message.content or ""

    # ── Anthropic backend ────────────────────────────────────────

    def _init_anthropic(self) -> None:
        """Create the Anthropic client from environment configuration.

        If the ``anthropic`` package is missing, a warning is logged and
        ``self._anthropic_client`` is set to ``None``; generation calls then
        fail with "Anthropic client not initialized".

        Raises:
            RuntimeError: If ``ANTHROPIC_API_KEY`` is unset or client
                construction fails.
        """
        try:
            from anthropic import Anthropic
        except ImportError:
            logger.warning(
                "anthropic package not installed. Install with: pip install anthropic"
            )
            self._anthropic_client = None
            return

        api_key = os.environ.get("ANTHROPIC_API_KEY", "")
        if not api_key:
            raise RuntimeError(
                "ANTHROPIC_API_KEY is not set. Export it before using "
                "llm_backend=anthropic."
            )
        try:
            self._anthropic_client = Anthropic(
                api_key=api_key,
                base_url=os.environ.get("ANTHROPIC_BASE_URL"),
            )
        except Exception as exc:
            raise RuntimeError(f"Failed to initialize Anthropic client: {exc}") from exc

    def _anthropic_generate(self, prompt: str, max_tokens: int | None = None) -> str:
        """Generate text with the Anthropic Messages API.

        Args:
            prompt: User message content.
            max_tokens: Output-token limit; falsy values use ``self.max_tokens``.

        Returns:
            The concatenated text of all ``text`` content blocks, stripped.

        Raises:
            RuntimeError: If the client was not initialised.
        """
        if self._anthropic_client is None:
            raise RuntimeError("Anthropic client not initialized")

        response = self._anthropic_client.messages.create(
            model=self.model_name,
            max_tokens=max_tokens or self.max_tokens,
            temperature=self.temperature,
            messages=[{"role": "user", "content": prompt}],
        )
        logger.info(
            "Anthropic response model served: requested=%s served=%s",
            self.model_name,
            getattr(response, "model", "<unknown>"),
        )
        # Non-text blocks (e.g. tool use) are ignored.
        parts = [
            getattr(block, "text", "")
            for block in getattr(response, "content", []) or []
            if getattr(block, "type", None) == "text"
        ]
        return "".join(parts).strip()

    # ── Public interface ─────────────────────────────────────────

    def generate_text(self, prompt: str, max_tokens: int | None = None) -> str:
        """Generate text from a prompt using the configured backend.

        Failed calls are retried up to ``_MAX_RETRIES`` attempts in total,
        sleeping ``_RETRY_DELAY_SEC * attempt`` seconds between attempts.

        Args:
            prompt: The input prompt.
            max_tokens: Optional override for max output tokens (honoured by
                every backend).

        Returns:
            Generated text response.

        Raises:
            RuntimeError: If all attempts fail.
        """
        for attempt in range(1, _MAX_RETRIES + 1):
            try:
                if self.backend == "vllm":
                    content = self._vllm_generate(prompt, max_tokens=max_tokens)
                elif self.backend == "openai":
                    content = self._openai_generate(prompt, max_tokens=max_tokens)
                else:
                    content = self._anthropic_generate(prompt, max_tokens=max_tokens)
                logger.debug(
                    "LLM text response (attempt %d, backend=%s, model=%s): %d chars",
                    attempt, self.backend, self.model_name, len(content),
                )
                return content
            except Exception as exc:
                logger.warning(
                    "LLM call attempt %d/%d failed: %s",
                    attempt, _MAX_RETRIES, exc,
                )
                if attempt < _MAX_RETRIES:
                    time.sleep(_RETRY_DELAY_SEC * attempt)
                else:
                    raise RuntimeError(
                        f"LLM call failed after {_MAX_RETRIES} attempts: {exc}"
                    ) from exc
        return ""  # only reached if _MAX_RETRIES < 1

    def generate_batch(
        self, prompts: list[str], max_tokens: int | None = None
    ) -> list[str]:
        """Generate text for multiple prompts, batching when possible.

        On the vLLM backend all prompts go to a single ``generate()`` call so
        the engine can batch them. On the OpenAI and Anthropic backends the
        prompts are sent one at a time through ``generate_text`` (with its
        retries).

        Args:
            prompts: List of prompt strings.
            max_tokens: Optional override for max output tokens.

        Returns:
            List of generated text responses (same order as prompts).
        """
        if not prompts:
            return []
        if self.backend == "vllm":
            return self._vllm_generate_batch(prompts, max_tokens=max_tokens)
        return [self.generate_text(p, max_tokens=max_tokens) for p in prompts]

    def _vllm_generate_batch(
        self, prompts: list[str], max_tokens: int | None = None
    ) -> list[str]:
        """Batch-generate with vLLM; all prompts are processed in one call.

        Returns one cleaned completion per prompt, in input order ("" where
        vLLM produced no output).
        """
        effective_max_tokens = max_tokens or self.max_tokens
        engine, tokenizer = self._vllm_engine()
        formatted_prompts = [
            self._vllm_format_prompt(tokenizer, prompt, effective_max_tokens)
            for prompt in prompts
        ]
        outputs = engine.generate(
            formatted_prompts, self._vllm_sampling_params(effective_max_tokens)
        )
        results = [
            self._clean_vllm_output(output.outputs[0].text) if output.outputs else ""
            for output in outputs
        ]
        logger.info(
            "Batch generated %d responses (%d total chars)",
            len(results), sum(len(r) for r in results),
        )
        return results

    def generate_json(self, prompt: str, schema_model: Type[T]) -> T:
        """Generate structured JSON and validate it against a Pydantic model.

        The response is reduced to its JSON payload with ``_extract_json``,
        which handles markdown fences and surrounding prose. The payload is
        then decoded and validated with ``schema_model.model_validate``. If
        decoding or validation fails, the next attempt uses a corrective
        prompt showing the invalid answer, the error, and the JSON schema.

        Backend failures are not retried here: ``generate_text`` already
        retries them and raises ``RuntimeError`` when it gives up.

        Args:
            prompt: The input prompt requesting JSON output.
            schema_model: Pydantic model class to parse the response into.

        Returns:
            Parsed Pydantic model instance.

        Raises:
            RuntimeError: If the backend call fails, or if parsing/validation
                still fails after ``_MAX_RETRIES`` attempts.
        """
        current_prompt = prompt
        for attempt in range(1, _MAX_RETRIES + 1):
            raw_text = self.generate_text(current_prompt)
            try:
                parsed = json.loads(self._extract_json(raw_text))
                return schema_model.model_validate(parsed)
            except ValueError as exc:
                # Covers json.JSONDecodeError and pydantic's ValidationError,
                # both of which subclass ValueError.
                logger.warning(
                    "JSON parse attempt %d/%d failed: %s\nRaw text: %.500s",
                    attempt, _MAX_RETRIES, exc, raw_text,
                )
                if attempt >= _MAX_RETRIES:
                    raise RuntimeError(
                        f"Failed to parse LLM JSON output after {_MAX_RETRIES} "
                        f"attempts: {exc}"
                    ) from exc
                current_prompt = self._build_json_retry_prompt(
                    prompt, raw_text, str(exc), schema_model
                )
                time.sleep(_RETRY_DELAY_SEC * attempt)
        raise RuntimeError("Unreachable")  # only if _MAX_RETRIES < 1

    @staticmethod
    def _build_json_retry_prompt(
        original_prompt: str,
        invalid_response: str,
        error: str,
        schema_model: Type[BaseModel],
    ) -> str:
        """Build a corrective retry prompt after invalid JSON output.

        The prompt appends the parser/schema error, the first 2500 characters
        of the invalid answer, and the first 4000 characters of the model's
        JSON schema to the original prompt.
        """
        schema = json.dumps(schema_model.model_json_schema(), ensure_ascii=False)
        return (
            f"{original_prompt}\n\n"
            "# Invalid JSON Retry\n"
            "Your previous answer could not be parsed as the required JSON object.\n"
            f"Parser/schema error: {error}\n\n"
            "Previous invalid answer:\n"
            "```text\n"
            f"{invalid_response[:2500]}\n"
            "```\n\n"
            "Required JSON schema:\n"
            "```json\n"
            f"{schema[:4000]}\n"
            "```\n\n"
            "Return ONLY one valid JSON object. Do not include markdown, comments, "
            "explanations, or trailing text."
        )

    @staticmethod
    def _extract_json(text: str) -> str:
        """Extract JSON from text, handling markdown code fences.

        Strategies (in order):
          1. Contents of a ```json ... ``` fence.
          2. Contents of a plain ``` ... ``` fence, if it starts with "{".
          3. The JSON object that starts at the first "{", if it decodes.
             Otherwise the span from the first "{" to the last "}".
          4. The stripped text as-is.
        """
        match = re.search(r"```json\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
        if match:
            return match.group(1).strip()

        match = re.search(r"```\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
        if match:
            candidate = match.group(1).strip()
            if candidate.startswith("{"):
                return candidate

        start = text.find("{")
        if start != -1:
            # Decode exactly one object so trailing prose containing braces
            # cannot corrupt the span; fall back to the wide span otherwise.
            try:
                _, stop = json.JSONDecoder().raw_decode(text, start)
            except ValueError:
                pass
            else:
                return text[start:stop]
            end = text.rfind("}")
            if end > start:
                return text[start : end + 1]

        return text.strip()