InsafQ commited on
Commit
aee859c
·
verified ·
1 Parent(s): 43ee677

Add tabgan/llm_api_client.py

Browse files
Files changed (1) hide show
  1. tabgan/llm_api_client.py +219 -0
tabgan/llm_api_client.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ LLM API Client for external text generation via API endpoints.
4
+ """
5
+
6
+ import logging
7
+ import json
8
+ from typing import Optional, Dict, Any, List
9
+ import requests
10
+
11
+ from tabgan.llm_config import LLMAPIConfig
12
+
13
+
14
+ class LLMAPIClient:
15
+ """Client for generating text via external LLM APIs (LM Studio, OpenAI, Ollama, etc.).
16
+
17
+ This client provides a unified interface for API-based text generation
18
+ that can be used alongside or instead of local models.
19
+
20
+ Example:
21
+ from tabgan.llm_config import LLMAPIConfig
22
+ from tabgan.llm_api_client import LLMAPIClient
23
+
24
+ # LM Studio
25
+ config = LLMAPIConfig.from_lm_studio(
26
+ base_url="http://localhost:1234",
27
+ model="google/gemma-3-12b"
28
+ )
29
+ client = LLMAPIClient(config)
30
+
31
+ text = client.generate("Generate a name for a female engineer, Age: 30: ")
32
+ """
33
+
34
+ def __init__(self, config: Optional[LLMAPIConfig] = None):
35
+ """
36
+ Initialize the API client with configuration.
37
+
38
+ Args:
39
+ config: LLMAPIConfig instance. If None, uses default LM Studio config.
40
+ """
41
+ self.config = config or LLMAPIConfig()
42
+ self.session = requests.Session()
43
+
44
+ def generate(self,
45
+ prompt: str,
46
+ max_tokens: Optional[int] = None,
47
+ temperature: Optional[float] = None,
48
+ system_prompt: Optional[str] = None) -> str:
49
+ """
50
+ Generate text from a prompt using the configured API.
51
+
52
+ Args:
53
+ prompt: The text prompt to send to the LLM
54
+ max_tokens: Maximum tokens to generate (overrides config)
55
+ temperature: Sampling temperature (overrides config)
56
+ system_prompt: Optional system prompt (overrides config)
57
+
58
+ Returns:
59
+ Generated text string
60
+
61
+ Raises:
62
+ requests.RequestException: If the API request fails
63
+ """
64
+ headers = self.config.get_headers()
65
+
66
+ # Build request payload based on API type
67
+ if "ollama" in self.config.chat_url or "11434" in self.config.chat_url:
68
+ payload = self._build_ollama_payload(prompt, max_tokens, temperature, system_prompt)
69
+ else:
70
+ # Default to OpenAI-compatible format (LM Studio, OpenAI, etc.)
71
+ payload = self._build_openai_payload(prompt, max_tokens, temperature, system_prompt)
72
+
73
+ try:
74
+ response = self.session.post(
75
+ self.config.chat_url,
76
+ headers=headers,
77
+ json=payload,
78
+ timeout=self.config.timeout
79
+ )
80
+ response.raise_for_status()
81
+
82
+ result = response.json()
83
+ return self._extract_response_text(result)
84
+
85
+ except requests.RequestException as e:
86
+ logging.error(f"LLM API request failed: {e}")
87
+ raise
88
+ except (KeyError, json.JSONDecodeError) as e:
89
+ logging.error(f"Failed to parse LLM API response: {e}")
90
+ raise
91
+
92
+ def _build_openai_payload(self,
93
+ prompt: str,
94
+ max_tokens: Optional[int],
95
+ temperature: Optional[float],
96
+ system_prompt: Optional[str]) -> Dict[str, Any]:
97
+ """Build OpenAI-compatible API request payload."""
98
+ messages: List[Dict[str, str]] = []
99
+
100
+ # Add system message if provided
101
+ sys_prompt = system_prompt or self.config.system_prompt
102
+ if sys_prompt:
103
+ messages.append({"role": "system", "content": sys_prompt})
104
+
105
+ messages.append({"role": "user", "content": prompt})
106
+
107
+ return {
108
+ "model": self.config.model,
109
+ "messages": messages,
110
+ "max_tokens": max_tokens or self.config.max_tokens,
111
+ "temperature": temperature or self.config.temperature,
112
+ "top_p": self.config.top_p,
113
+ }
114
+
115
+ def _build_ollama_payload(self,
116
+ prompt: str,
117
+ max_tokens: Optional[int],
118
+ temperature: Optional[float],
119
+ system_prompt: Optional[str]) -> Dict[str, Any]:
120
+ """Build Ollama API request payload."""
121
+ payload = {
122
+ "model": self.config.model,
123
+ "prompt": prompt,
124
+ "stream": False,
125
+ "options": {
126
+ "temperature": temperature or self.config.temperature,
127
+ "top_p": self.config.top_p,
128
+ "top_k": self.config.top_k,
129
+ }
130
+ }
131
+
132
+ # Add system prompt if provided
133
+ sys_prompt = system_prompt or self.config.system_prompt
134
+ if sys_prompt:
135
+ payload["system"] = sys_prompt
136
+
137
+ # Ollama uses num_predict for max tokens
138
+ if max_tokens:
139
+ payload["options"]["num_predict"] = max_tokens
140
+ elif self.config.max_tokens:
141
+ payload["options"]["num_predict"] = self.config.max_tokens
142
+
143
+ return payload
144
+
145
+ def _extract_response_text(self, result: Dict[str, Any]) -> str:
146
+ """Extract generated text from API response."""
147
+ # OpenAI-compatible format
148
+ if "choices" in result and len(result["choices"]) > 0:
149
+ choice = result["choices"][0]
150
+ if "message" in choice:
151
+ return choice["message"].get("content", "").strip()
152
+ elif "text" in choice:
153
+ return choice["text"].strip()
154
+
155
+ # Ollama format
156
+ if "response" in result:
157
+ return result["response"].strip()
158
+
159
+ # Fallback: try to find any string content
160
+ logging.warning(f"Unexpected API response format: {result}")
161
+ return str(result)
162
+
163
+ def generate_batch(self,
164
+ prompts: List[str],
165
+ max_tokens: Optional[int] = None,
166
+ temperature: Optional[float] = None) -> List[str]:
167
+ """
168
+ Generate text for multiple prompts sequentially.
169
+
170
+ Args:
171
+ prompts: List of prompts to generate from
172
+ max_tokens: Maximum tokens per generation
173
+ temperature: Sampling temperature
174
+
175
+ Returns:
176
+ List of generated text strings
177
+ """
178
+ results = []
179
+ for i, prompt in enumerate(prompts):
180
+ try:
181
+ text = self.generate(prompt, max_tokens, temperature)
182
+ results.append(text)
183
+ except requests.RequestException as e:
184
+ logging.error(f"Failed to generate for prompt {i}: {e}")
185
+ results.append("")
186
+ return results
187
+
188
+ def check_connection(self) -> bool:
189
+ """
190
+ Check if the API endpoint is accessible.
191
+
192
+ Returns:
193
+ True if connection successful, False otherwise
194
+ """
195
+ try:
196
+ # Try to get models list or just check if server responds
197
+ if "ollama" in self.config.base_url or "11434" in self.config.base_url:
198
+ test_url = f"{self.config.base_url.rstrip('/')}/api/tags"
199
+ else:
200
+ # OpenAI-compatible: try /models endpoint
201
+ test_url = f"{self.config.base_url.rstrip('/')}/v1/models"
202
+
203
+ response = self.session.get(
204
+ test_url,
205
+ headers=self.config.get_headers(),
206
+ timeout=5
207
+ )
208
+ return response.status_code == 200
209
+ except requests.RequestException:
210
+ return False
211
+
212
+ def __enter__(self):
213
+ """Context manager entry."""
214
+ return self
215
+
216
+ def __exit__(self, exc_type, exc_val, exc_tb):
217
+ """Context manager exit - close session."""
218
+ self.session.close()
219
+ return False