tabgan/llm_api_client.py
# -*- coding: utf-8 -*-
"""
LLM API Client for external text generation via API endpoints.
"""
import json
import logging
from typing import Any, Dict, List, Optional

import requests

from tabgan.llm_config import LLMAPIConfig


class LLMAPIClient:
    """Client for generating text via external LLM APIs (LM Studio, OpenAI, Ollama, etc.).

    This client provides a unified interface for API-based text generation
    that can be used alongside or instead of local models.

    Example:
        from tabgan.llm_config import LLMAPIConfig
        from tabgan.llm_api_client import LLMAPIClient

        # LM Studio
        config = LLMAPIConfig.from_lm_studio(
            base_url="http://localhost:1234",
            model="google/gemma-3-12b"
        )
        client = LLMAPIClient(config)
        text = client.generate("Generate a name for a female engineer, Age: 30: ")
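
        The client is also a context manager that closes its HTTP session on
        exit, and check_connection() lets you probe the endpoint first:

        with LLMAPIClient(config) as client:
            if client.check_connection():
                text = client.generate("Generate a name for a female engineer, Age: 30: ")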
"""

    def __init__(self, config: Optional[LLMAPIConfig] = None):
        """
        Initialize the API client with configuration.

        Args:
            config: LLMAPIConfig instance. If None, uses the default LM Studio config.
        """
        self.config = config or LLMAPIConfig()
        self.session = requests.Session()

    def generate(self,
                 prompt: str,
                 max_tokens: Optional[int] = None,
                 temperature: Optional[float] = None,
                 system_prompt: Optional[str] = None) -> str:
        """
        Generate text from a prompt using the configured API.

        Args:
            prompt: The text prompt to send to the LLM
            max_tokens: Maximum tokens to generate (overrides config)
            temperature: Sampling temperature (overrides config)
            system_prompt: Optional system prompt (overrides config)

        Returns:
            Generated text string

        Raises:
            requests.RequestException: If the API request fails
        """
        headers = self.config.get_headers()
        # Build request payload based on API type
        if "ollama" in self.config.chat_url or "11434" in self.config.chat_url:
            payload = self._build_ollama_payload(prompt, max_tokens, temperature, system_prompt)
        else:
            # Default to OpenAI-compatible format (LM Studio, OpenAI, etc.)
            payload = self._build_openai_payload(prompt, max_tokens, temperature, system_prompt)
        try:
            response = self.session.post(
                self.config.chat_url,
                headers=headers,
                json=payload,
                timeout=self.config.timeout,
            )
            response.raise_for_status()
            result = response.json()
            return self._extract_response_text(result)
        except requests.RequestException as e:
            logging.error(f"LLM API request failed: {e}")
            raise
        except (KeyError, json.JSONDecodeError) as e:
            logging.error(f"Failed to parse LLM API response: {e}")
            raise
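
    # Per-call arguments override the config defaults; for example
    # (illustrative values):
    #
    #   client.generate(
    #       "Generate a city name: ",
    #       max_tokens=16,
    #       temperature=0.2,
    #       system_prompt="Answer with a single value only.",
    #   )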

    def _build_openai_payload(self,
                              prompt: str,
                              max_tokens: Optional[int],
                              temperature: Optional[float],
                              system_prompt: Optional[str]) -> Dict[str, Any]:
        """Build OpenAI-compatible API request payload."""
        messages: List[Dict[str, str]] = []
        # Add system message if provided
        sys_prompt = system_prompt or self.config.system_prompt
        if sys_prompt:
            messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": prompt})
        return {
            "model": self.config.model,
            "messages": messages,
            # Explicit None checks so that 0 / 0.0 per-call overrides are honored
            "max_tokens": max_tokens if max_tokens is not None else self.config.max_tokens,
            "temperature": temperature if temperature is not None else self.config.temperature,
            "top_p": self.config.top_p,
        }
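
    # Sketch of the resulting OpenAI-style payload (values are illustrative):
    #
    #   {
    #       "model": "google/gemma-3-12b",
    #       "messages": [
    #           {"role": "system", "content": "<system prompt>"},
    #           {"role": "user", "content": "<prompt>"}
    #       ],
    #       "max_tokens": 64,
    #       "temperature": 0.7,
    #       "top_p": 0.9
    #   }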

    def _build_ollama_payload(self,
                              prompt: str,
                              max_tokens: Optional[int],
                              temperature: Optional[float],
                              system_prompt: Optional[str]) -> Dict[str, Any]:
        """Build Ollama API request payload."""
        payload: Dict[str, Any] = {
            "model": self.config.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                # Explicit None check so that a 0.0 per-call override is honored
                "temperature": temperature if temperature is not None else self.config.temperature,
                "top_p": self.config.top_p,
                "top_k": self.config.top_k,
            },
        }
        # Add system prompt if provided
        sys_prompt = system_prompt or self.config.system_prompt
        if sys_prompt:
            payload["system"] = sys_prompt
        # Ollama uses num_predict for max tokens
        if max_tokens is not None:
            payload["options"]["num_predict"] = max_tokens
        elif self.config.max_tokens:
            payload["options"]["num_predict"] = self.config.max_tokens
        return payload
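
    # Sketch of the resulting Ollama /api/generate payload (illustrative):
    #
    #   {
    #       "model": "llama3",
    #       "prompt": "<prompt>",
    #       "stream": false,
    #       "system": "<system prompt>",
    #       "options": {"temperature": 0.7, "top_p": 0.9, "top_k": 40,
    #                   "num_predict": 64}
    #   }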

    def _extract_response_text(self, result: Dict[str, Any]) -> str:
        """Extract generated text from API response."""
        # OpenAI-compatible format
        if "choices" in result and len(result["choices"]) > 0:
            choice = result["choices"][0]
            if "message" in choice:
                # "content" can be present but null, so guard before stripping
                return (choice["message"].get("content") or "").strip()
            elif "text" in choice:
                return choice["text"].strip()
        # Ollama format
        if "response" in result:
            return result["response"].strip()
        # Fallback: try to find any string content
        logging.warning(f"Unexpected API response format: {result}")
        return str(result)
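
    # The two response shapes handled above (illustrative values):
    #
    #   OpenAI-compatible: {"choices": [{"message": {"role": "assistant",
    #                                                "content": "Alice"}}]}
    #   Ollama:            {"response": "Alice", "done": true}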

    def generate_batch(self,
                       prompts: List[str],
                       max_tokens: Optional[int] = None,
                       temperature: Optional[float] = None) -> List[str]:
        """
        Generate text for multiple prompts sequentially.

        Args:
            prompts: List of prompts to generate from
            max_tokens: Maximum tokens per generation
            temperature: Sampling temperature

        Returns:
            List of generated text strings; an empty string marks a failed prompt
        """
        results = []
        for i, prompt in enumerate(prompts):
            try:
                text = self.generate(prompt, max_tokens, temperature)
                results.append(text)
            except requests.RequestException as e:
                logging.error(f"Failed to generate for prompt {i}: {e}")
                results.append("")
        return results
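
    # Example (illustrative): the result list always aligns index-for-index
    # with the input prompts, with "" marking a failed request:
    #
    #   names = client.generate_batch(["Name for row 1: ", "Name for row 2: "])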

    def check_connection(self) -> bool:
        """
        Check if the API endpoint is accessible.

        Returns:
            True if connection successful, False otherwise
        """
        try:
            # Try to get models list or just check if server responds
            if "ollama" in self.config.base_url or "11434" in self.config.base_url:
                test_url = f"{self.config.base_url.rstrip('/')}/api/tags"
            else:
                # OpenAI-compatible: try /models endpoint
                test_url = f"{self.config.base_url.rstrip('/')}/v1/models"
            response = self.session.get(
                test_url,
                headers=self.config.get_headers(),
                timeout=5,
            )
            return response.status_code == 200
        except requests.RequestException:
            return False

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - close session."""
        self.session.close()
        return False
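

if __name__ == "__main__":
    # Minimal smoke test; assumes an OpenAI-compatible server (e.g. LM Studio)
    # is listening locally. The URL and model name are illustrative.
    demo_config = LLMAPIConfig.from_lm_studio(
        base_url="http://localhost:1234",
        model="google/gemma-3-12b",
    )
    with LLMAPIClient(demo_config) as demo_client:
        if demo_client.check_connection():
            print(demo_client.generate("Generate a name for a female engineer, Age: 30: "))
        else:
            print("LLM API endpoint is not reachable.")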