CB commited on
Commit
3b84c09
·
verified ·
1 Parent(s): a20ea0f

Create genai_client.py

Browse files
Files changed (1) hide show
  1. genai_client.py +250 -0
genai_client.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Compatibility wrapper for google.generativeai surfaces.
3
+
4
+ - Reads API key from env var GOOGLE_API_KEY
5
+ - Detects older global-style surfaces (genai.configure / GenerativeModel / genai.responses)
6
+ and newer client-style surfaces (NewClient / genai.Client / client.responses)
7
+ - Exposes:
8
+ - create_response(**kwargs) -> response object (SDK response)
9
+ - upload_file(path, mime_type=None) -> file handle/object for use with responses/files
10
+ - set_model_default(model_name)
11
+ - Minimal surface: adapt your existing calls to import these helpers and call them.
12
+ """
13
+
14
+ import os
15
+ import logging
16
+ from pathlib import Path
17
+ from typing import Optional, Any, Dict
18
+
19
+ logger = logging.getLogger(__name__)
20
+ logger.setLevel(logging.INFO)
21
+
22
+ # Load api key from env
23
+ API_KEY = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_GENAI_API_KEY")
24
+ if not API_KEY:
25
+ logger.warning("No GOOGLE_API_KEY (or GEMINI_API_KEY) found in environment. Set it before calling the API.")
26
+
27
+ # Import SDK and detect available surfaces
28
+ try:
29
+ import google.generativeai as genai
30
+ except Exception as e:
31
+ raise RuntimeError("google.generativeai import failed: " + str(e))
32
+
33
+ def _list_attrs():
34
+ return [k for k in dir(genai) if not k.startswith("_")]
35
+
36
+ _HAS_RESPONSES = hasattr(genai, "responses") or hasattr(genai, "Responses")
37
+ _HAS_NEWCLIENT = hasattr(genai, "NewClient") or hasattr(genai, "Client") or hasattr(genai, "client")
38
+ _HAS_GENERATIVEMODEL = hasattr(genai, "GenerativeModel") or hasattr(genai, "GenerativeModelV1")
39
+
40
+ logger.info("genai attributes detected: %s", _list_attrs())
41
+ logger.info("has responses: %s, has NewClient/Client: %s, has GenerativeModel: %s",
42
+ _HAS_RESPONSES, _HAS_NEWCLIENT, _HAS_GENERATIVEMODEL)
43
+
44
+ # Configure API key for older-style SDKs if possible
45
+ try:
46
+ # many 0.x versions use genai.configure or genai.api_key
47
+ if hasattr(genai, "configure"):
48
+ try:
49
+ genai.configure(api_key=API_KEY)
50
+ except Exception:
51
+ # some versions expect env var only
52
+ pass
53
+ elif hasattr(genai, "api_key"):
54
+ try:
55
+ genai.api_key = API_KEY
56
+ except Exception:
57
+ pass
58
+ except Exception:
59
+ pass
60
+
61
+ # Create a client when available (newer unified client)
62
+ _client = None
63
+ if _HAS_NEWCLIENT:
64
+ try:
65
+ if hasattr(genai, "NewClient"):
66
+ _client = genai.NewClient(api_key=API_KEY) if API_KEY else genai.NewClient()
67
+ elif hasattr(genai, "Client"):
68
+ # some variants expose Client
69
+ _client = genai.Client(api_key=API_KEY) if API_KEY else genai.Client()
70
+ elif hasattr(genai, "client"):
71
+ # wrapper namespace
72
+ _client = genai.client.NewClient(api_key=API_KEY) if API_KEY else genai.client.NewClient()
73
+ logger.info("Initialized genai client: %s", type(_client))
74
+ except Exception as e:
75
+ logger.warning("Failed to initialize genai client: %s", e)
76
+ _client = None
77
+
78
+ _DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "gemini-2.5-flash-lite")
79
+
80
+ def set_model_default(model_name: str):
81
+ global _DEFAULT_MODEL
82
+ _DEFAULT_MODEL = model_name
83
+
84
+ def _read_bytes(path: str) -> bytes:
85
+ return Path(path).read_bytes()
86
+
87
+ def upload_file(path: str, mime_type: Optional[str] = None) -> Any:
88
+ """
89
+ Upload a file to the SDK's file service if supported, or return a dict
90
+ that can be passed to create_response as raw bytes depending on the SDK.
91
+ Returns the uploaded file handle/object expected by the SDK in create_response.
92
+ """
93
+ path = str(path)
94
+ data = _read_bytes(path)
95
+ filename = Path(path).name
96
+
97
+ # New client: client.files.upload or client.files.create depending on version
98
+ if _client is not None:
99
+ # try common upload surfaces
100
+ for attr in ("files", "Files", "file_manager", "FileManager"):
101
+ files_container = getattr(_client, attr, None)
102
+ if files_container:
103
+ # try upload method names
104
+ for m in ("upload", "upload_file", "create", "create_file", "upload_bytes"):
105
+ if hasattr(files_container, m):
106
+ func = getattr(files_container, m)
107
+ try:
108
+ logger.info("Uploading via client.%s.%s", attr, m)
109
+ return func(file_bytes=data, filename=filename, mime_type=mime_type)
110
+ except TypeError:
111
+ # try alternative signature
112
+ try:
113
+ return func(path)
114
+ except Exception:
115
+ pass
116
+ except Exception as e:
117
+ logger.warning("client file upload failed: %s", e)
118
+ # fallback: some SDKs expect an object literal
119
+ try:
120
+ return {"name": filename, "bytes": data, "mime_type": mime_type}
121
+ except Exception:
122
+ pass
123
+
124
+ # Older/global-style surfaces: genai.Files or genai.upload_file or genai.responses.upload
125
+ for candidate in ("Files", "files", "upload_file", "upload", "responses"):
126
+ svc = getattr(genai, candidate, None)
127
+ if svc:
128
+ # try common methods
129
+ for m in ("upload", "upload_file", "create"):
130
+ if hasattr(svc, m):
131
+ try:
132
+ func = getattr(svc, m)
133
+ logger.info("Uploading via genai.%s.%s", candidate, m)
134
+ return func(path) # many older wrappers accept path
135
+ except Exception:
136
+ pass
137
+ # fallback return dict
138
+ try:
139
+ return {"name": filename, "bytes": data, "mime_type": mime_type}
140
+ except Exception:
141
+ pass
142
+
143
+ # Last resort: return bytes dict for manual usage in create_response
144
+ return {"name": filename, "bytes": data, "mime_type": mime_type}
145
+
146
+ def create_response(*, model: Optional[str] = None, input: Optional[Any] = None, **kwargs) -> Any:
147
+ """
148
+ Unified call to generate a response.
149
+
150
+ Common kwargs you'll want to pass: multimodal contents (contents), files, temperature, max_output_tokens
151
+ Example:
152
+ create_response(model="gemini-2.5-flash-lite", input="Hello")
153
+ create_response(contents=["Describe this image", {"name":"img.png","bytes":b"..."}])
154
+ """
155
+ model = model or _DEFAULT_MODEL
156
+ # Prefer client.responses.create / client.models.generate_content depending on SDK
157
+ if _client is not None:
158
+ # newer unified client surfaces
159
+ # many SDKs expose client.responses.create or client.models.generate_content
160
+ if hasattr(_client, "responses") and hasattr(_client.responses, "create"):
161
+ logger.info("Calling client.responses.create with model=%s", model)
162
+ payload = kwargs.copy()
163
+ payload.setdefault("model", model)
164
+ if input is not None:
165
+ # some SDKs accept 'input' or 'contents'
166
+ payload.setdefault("input", input)
167
+ return _client.responses.create(**payload)
168
+ if hasattr(_client, "models") and hasattr(_client.models, "generate_content"):
169
+ logger.info("Calling client.models.generate_content with model=%s", model)
170
+ payload = kwargs.copy()
171
+ payload.setdefault("model", model)
172
+ if input is not None:
173
+ payload.setdefault("contents", input)
174
+ return _client.models.generate_content(**payload)
175
+
176
+ # Global/genai module-style surfaces
177
+ if hasattr(genai, "responses") and hasattr(genai.responses, "create"):
178
+ logger.info("Calling genai.responses.create with model=%s", model)
179
+ payload = kwargs.copy()
180
+ payload.setdefault("model", model)
181
+ if input is not None:
182
+ payload.setdefault("input", input)
183
+ return genai.responses.create(**payload)
184
+
185
+ # Older GenerativeModel style
186
+ if _HAS_GENERATIVEMODEL:
187
+ try:
188
+ # instantiate model wrapper and call generate_content / start_chat style
189
+ ModelClass = getattr(genai, "GenerativeModel", None) or getattr(genai, "GenerativeModelV1", None)
190
+ if ModelClass:
191
+ m = ModelClass(model)
192
+ logger.info("Calling GenerativeModel.generate_content with model=%s", model)
193
+ if input is not None:
194
+ return m.generate_content(input if not isinstance(input, (list, tuple)) else input)
195
+ # pass kwargs as best-effort
196
+ if "contents" in kwargs:
197
+ return m.generate_content(kwargs["contents"])
198
+ return m.generate_content(kwargs)
199
+ except Exception as e:
200
+ logger.warning("GenerativeModel path failed: %s", e)
201
+
202
+ raise RuntimeError("No compatible google.generativeai surface found. See logs for details.")
203
+
204
+ # Small helper that returns a plain text result if present (convenience)
205
+ def extract_text(resp: Any) -> str:
206
+ """
207
+ Attempt to extract a concise text result from many SDK response shapes.
208
+ """
209
+ if resp is None:
210
+ return ""
211
+ # New unified: resp.output[0].content or resp.output_text
212
+ try:
213
+ if hasattr(resp, "output_text"):
214
+ return getattr(resp, "output_text") or ""
215
+ if hasattr(resp, "output"):
216
+ out = resp.output
217
+ # output may be list of dicts with 'content'/'text'
218
+ if isinstance(out, (list, tuple)) and len(out) > 0:
219
+ first = out[0]
220
+ if isinstance(first, dict):
221
+ return first.get("content") or first.get("text") or ""
222
+ if hasattr(first, "content"):
223
+ return getattr(first, "content") or ""
224
+ # older: resp.text
225
+ if hasattr(resp, "text"):
226
+ return getattr(resp, "text") or ""
227
+ # dict-like
228
+ if isinstance(resp, dict):
229
+ for k in ("output_text", "text", "result", "content", "output"):
230
+ if k in resp:
231
+ val = resp[k]
232
+ if isinstance(val, str):
233
+ return val
234
+ if isinstance(val, list) and len(val) > 0 and isinstance(val[0], str):
235
+ return val[0]
236
+ except Exception:
237
+ pass
238
+ # fallback: string representation
239
+ try:
240
+ return str(resp)
241
+ except Exception:
242
+ return ""
243
+
244
+ # Exported API
245
+ __all__ = [
246
+ "create_response",
247
+ "upload_file",
248
+ "set_model_default",
249
+ "extract_text",
250
+ ]