File size: 28,702 Bytes
8e0dd55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
"""OpenRouter ModelClient integration."""

from typing import Dict, Sequence, Optional, Any, List
import logging
import json
import aiohttp
import requests
from requests.exceptions import RequestException, Timeout

from adalflow.core.model_client import ModelClient
from adalflow.core.types import (
    CompletionUsage,
    ModelType,
    GeneratorOutput,
)

log = logging.getLogger(__name__)

class OpenRouterClient(ModelClient):
    __doc__ = r"""A component wrapper for the OpenRouter API client.

    OpenRouter provides a unified API that gives access to hundreds of AI models through a single endpoint.
    The API is compatible with OpenAI's API format with a few small differences.

    Visit https://openrouter.ai/docs for more details.

    Example:
        ```python
        from api.openrouter_client import OpenRouterClient

        client = OpenRouterClient()
        generator = adal.Generator(
            model_client=client,
            model_kwargs={"model": "openai/gpt-4o"}
        )
        ```
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the OpenRouter client.

        All positional and keyword arguments are forwarded unchanged to the
        base-class initializer. The synchronous client settings are built
        right away via ``init_sync_client``; the asynchronous settings are
        left as ``None`` and created on the first ``acall``.
        """
        super().__init__(*args, **kwargs)
        self.sync_client = self.init_sync_client()
        self.async_client = None  # Initialize async client only when needed

    def init_sync_client(self):
        """Build the settings used for synchronous OpenRouter requests.

        Returns:
            A dict holding ``api_key`` (whatever ``OPENROUTER_API_KEY`` is set
            to, even when falsy) and the OpenRouter ``base_url``. A warning is
            logged when the key is missing; no exception is raised.
        """
        from api.config import OPENROUTER_API_KEY

        # There is no dedicated OpenRouter client library, so the "client" is
        # just the connection settings that direct HTTP calls need.
        settings = {
            "api_key": OPENROUTER_API_KEY,
            "base_url": "https://openrouter.ai/api/v1",
        }

        if not settings["api_key"]:
            log.warning("OPENROUTER_API_KEY not configured")

        return settings

    def init_async_client(self):
        """Build the settings used for asynchronous OpenRouter requests.

        Returns:
            A dict with the configured ``api_key`` and the OpenRouter
            ``base_url``. A missing key only produces a logged warning.
        """
        from api.config import OPENROUTER_API_KEY as configured_key

        if not configured_key:
            log.warning("OPENROUTER_API_KEY not configured")

        # Async calls are made with aiohttp; these settings are all it needs.
        return dict(
            api_key=configured_key,
            base_url="https://openrouter.ai/api/v1",
        )

    def convert_inputs_to_api_kwargs(
        self, input: Any, model_kwargs: Dict = None, model_type: ModelType = None
    ) -> Dict:
        """Translate AdalFlow inputs into an OpenRouter chat-completion body.

        Args:
            input: Either a prompt string (sent as a single user message) or a
                list of message dicts that is passed through as-is.
            model_kwargs: Extra request fields merged on top of ``messages``.
            model_type: Must be ``ModelType.LLM``.

        Returns:
            The request body, with ``model`` defaulting to
            ``"openai/gpt-3.5-turbo"`` when not supplied.

        Raises:
            ValueError: For an unsupported input shape or model type.
            NotImplementedError: For ``ModelType.EMBEDDING``.
        """
        extra_fields = model_kwargs or {}

        if model_type == ModelType.LLM:
            # A bare string becomes a one-message conversation; a list of
            # dicts is assumed to already be in chat format.
            if isinstance(input, str):
                chat = [{"role": "user", "content": input}]
            elif isinstance(input, list) and all(isinstance(msg, dict) for msg in input):
                chat = input
            else:
                raise ValueError(f"Unsupported input format for OpenRouter: {type(input)}")

            # For debugging
            log.info(f"Messages for OpenRouter: {chat}")

            request_body = {"messages": chat}
            request_body.update(extra_fields)

            # Fall back to a default model when the caller did not pick one.
            request_body.setdefault("model", "openai/gpt-3.5-turbo")
            return request_body

        if model_type == ModelType.EMBEDDING:
            # Embeddings are not routed through OpenRouter by this client.
            raise NotImplementedError("OpenRouter client does not support embeddings yet")

        raise ValueError(f"Unsupported model type: {model_type}")

    async def acall(self, api_kwargs: Dict = None, model_type: ModelType = None) -> Any:
        """Make an asynchronous, non-streaming call to the OpenRouter chat API.

        Problems are never raised to the caller. Every outcome, success or
        failure, is returned as an async generator that yields exactly one
        string, so error text can be shown to the user through the same
        streaming path as a normal answer.

        Args:
            api_kwargs: Chat-completion request body. It is not modified; a
                copy with ``"stream": False`` is what gets sent.
            model_type: Only ``ModelType.LLM`` is supported.

        Returns:
            An async generator yielding either the assistant message content
            (with ``<wiki_structure>`` XML normalized when present) or a
            human-readable error message.
        """
        import asyncio

        if not self.async_client:
            self.async_client = self.init_async_client()

        if not self.async_client.get("api_key"):
            error_msg = "OPENROUTER_API_KEY not configured. Please set this environment variable to use OpenRouter."
            log.error(error_msg)
            return self._single_message_stream(error_msg)

        if model_type != ModelType.LLM:
            error_msg = f"Unsupported model type: {model_type}"
            log.error(error_msg)
            return self._single_message_stream(error_msg)

        url = f"{self.async_client['base_url']}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.async_client['api_key']}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/AsyncFuncAI/deepwiki-open",  # Optional
            "X-Title": "DeepWiki",  # Optional
        }
        # Copy instead of mutating the caller's dict; OpenRouter is always
        # called in non-streaming mode from here.
        payload = {**(api_kwargs or {}), "stream": False}
        timeout_seconds = 60

        try:
            log.info(f"Making async OpenRouter API call to {url}")
            # The bearer token must never reach the logs.
            log.info("Request headers: %s", {**headers, "Authorization": "Bearer ***"})
            log.info(f"Request body: {payload}")

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url,
                    headers=headers,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=timeout_seconds),
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        error_msg = f"OpenRouter API error ({response.status}): {error_text}"
                        log.error(error_msg)
                        return self._single_message_stream(error_msg)

                    data = await response.json()

            log.info(f"Received response from OpenRouter: {data}")
            return self._completion_content_stream(data)

        except asyncio.TimeoutError:
            # str(TimeoutError()) is empty, so give the user an explicit message.
            error_msg = f"Request to OpenRouter API timed out after {timeout_seconds} seconds"
            log.error(error_msg)
            return self._single_message_stream(error_msg)

        except aiohttp.ClientError as e_client:
            log.error(f"Connection error with OpenRouter API: {str(e_client)}")
            return self._single_message_stream(
                f"Connection error with OpenRouter API: {str(e_client)}. Please check your internet connection and that the OpenRouter API is accessible."
            )

        except Exception as e_unexp:
            log.error(f"Unexpected error calling OpenRouter API asynchronously: {str(e_unexp)}")
            return self._single_message_stream(
                f"Unexpected error calling OpenRouter API: {str(e_unexp)}"
            )

    @staticmethod
    async def _single_message_stream(message: str):
        """Async generator that yields ``message`` once and then finishes."""
        yield message

    async def _completion_content_stream(self, data: Any):
        """Yield the assistant text of a chat-completion payload, or an error line.

        Exactly one string is yielded. A payload without choices, or whose
        first choice has no string ``message.content`` (for example ``null``),
        produces an error message instead of raising.
        """
        choices = data.get("choices") if isinstance(data, dict) else None
        if not isinstance(choices, list) or not choices:
            log.error(f"No choices in response: {data}")
            yield "Error: No response content from OpenRouter API"
            return

        choice = choices[0]
        message = choice.get("message") if isinstance(choice, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        if not isinstance(content, str):
            log.error(f"Unexpected response format: {data}")
            yield "Error: Unexpected response format from OpenRouter API"
            return

        log.info("Successfully retrieved response")
        try:
            text = self._normalize_wiki_structure_xml(content)
        except Exception as xml_error:
            # Normalization is best-effort; fall back to the raw model output.
            log.error(f"Error processing XML content: {str(xml_error)}")
            text = content
        yield text

    @staticmethod
    def _normalize_wiki_structure_xml(content: str) -> str:
        """Return ``content`` with an embedded ``<wiki_structure>`` document cleaned up.

        Content that does not look like XML, or that has no complete
        ``<wiki_structure>`` element, is returned unchanged. Otherwise the
        element is extracted, bare ampersands are escaped, and the result is
        pretty-printed. If it still fails to parse, the structure is rebuilt
        from its known tags; if that fails too, the extracted XML is returned
        as-is.
        """
        import re
        from xml.dom.minidom import parseString

        if not (content.strip().startswith("<") and ">" in content):
            return content
        if "<wiki_structure>" not in content:
            return content

        log.info("Found wiki_structure XML, ensuring proper format")
        wiki_match = re.search(r"<wiki_structure>[\s\S]*?</wiki_structure>", content)
        if not wiki_match:
            log.warning("Could not extract wiki_structure XML, yielding original content")
            return content

        clean_xml = wiki_match.group(0).strip()

        try:
            # Escape '&' unless it already starts a named or numeric entity.
            fixed_xml = re.sub(
                r"&(?!(?:amp|lt|gt|apos|quot|#[0-9]+|#x[0-9a-fA-F]+);)",
                "&amp;",
                clean_xml,
            )
            fixed_xml = fixed_xml.replace("  >", ">")

            pretty_xml = parseString(fixed_xml).toprettyxml()

            # Drop the XML declaration that toprettyxml() prepends.
            if pretty_xml.startswith("<?xml"):
                pretty_xml = pretty_xml[pretty_xml.find("?>") + 2:].strip()

            log.info(f"Extracted and validated XML: {pretty_xml[:100]}...")
            return pretty_xml
        except Exception as xml_parse_error:
            log.warning(f"XML validation failed: {str(xml_parse_error)}, using raw XML")

        try:
            return OpenRouterClient._rebuild_wiki_structure(clean_xml)
        except Exception as rebuild_error:
            log.warning(f"Failed to rebuild XML: {str(rebuild_error)}, using raw XML")
            return clean_xml

    @staticmethod
    def _rebuild_wiki_structure(clean_xml: str) -> str:
        """Re-assemble a ``<wiki_structure>`` document from regex-extracted fields.

        Used when the model's XML cannot be parsed. Only the known tags
        (title, description, pages, importance, file_path, related) are kept.
        """
        import re

        def first_text(tag: str, text: str) -> Optional[str]:
            # Stripped text of the first <tag>...</tag>, or None when absent.
            found = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
            return found.group(1).strip() if found else None

        structure_match = re.search(
            r"<wiki_structure>(.*?)</wiki_structure>", clean_xml, re.DOTALL
        )
        if not structure_match:
            log.warning("Could not extract wiki structure, using raw XML")
            return clean_xml

        structure = structure_match.group(1).strip()
        parts = ["<wiki_structure>\n"]

        for tag in ("title", "description"):
            value = first_text(tag, structure)
            if value is not None:
                parts.append(f"  <{tag}>{value}</{tag}>\n")

        parts.append("  <pages>\n")
        pages = re.findall(r'<page id="(.*?)">(.*?)</page>', structure, re.DOTALL)
        for page_id, page_content in pages:
            parts.append(f'    <page id="{page_id}">\n')

            for tag in ("title", "description", "importance"):
                value = first_text(tag, page_content)
                if value is not None:
                    parts.append(f"      <{tag}>{value}</{tag}>\n")

            parts.append("      <relevant_files>\n")
            parts.extend(
                f"        <file_path>{file_path.strip()}</file_path>\n"
                for file_path in re.findall(
                    r"<file_path>(.*?)</file_path>", page_content, re.DOTALL
                )
            )
            parts.append("      </relevant_files>\n")

            parts.append("      <related_pages>\n")
            parts.extend(
                f"        <related>{related.strip()}</related>\n"
                for related in re.findall(
                    r"<related>(.*?)</related>", page_content, re.DOTALL
                )
            )
            parts.append("      </related_pages>\n")

            parts.append("    </page>\n")

        parts.append("  </pages>\n</wiki_structure>")

        log.info("Successfully rebuilt clean XML structure")
        return "".join(parts)

    def _process_completion_response(self, data: Dict) -> GeneratorOutput:
        """Turn a non-streaming OpenRouter completion payload into a GeneratorOutput.

        Args:
            data: The decoded JSON response body.

        Returns:
            A ``GeneratorOutput`` whose ``data`` is the first choice's text,
            with token usage attached when the payload reports it and the
            original payload kept as ``raw_response``.

        Raises:
            ValueError: If the payload has no choices, or the first choice has
                neither a ``message`` nor a ``text`` field.
            Exception: Any other processing error is logged and re-raised.
        """
        try:
            choices = data.get("choices")
            if not choices:
                raise ValueError(f"No choices in OpenRouter response: {data}")

            first_choice = choices[0]

            # Chat responses carry "message"; legacy completions carry "text".
            if "message" in first_choice:
                text = first_choice["message"].get("content", "")
            elif "text" in first_choice:
                text = first_choice.get("text", "")
            else:
                raise ValueError(f"Unexpected response format from OpenRouter: {first_choice}")

            # Usage is optional in the payload.
            token_usage = None
            if "usage" in data:
                reported = data["usage"]
                token_usage = CompletionUsage(
                    prompt_tokens=reported.get("prompt_tokens", 0),
                    completion_tokens=reported.get("completion_tokens", 0),
                    total_tokens=reported.get("total_tokens", 0),
                )

            return GeneratorOutput(data=text, usage=token_usage, raw_response=data)

        except Exception as e_proc:
            log.error(f"Error processing OpenRouter completion response: {str(e_proc)}")
            raise

    def _process_streaming_response(self, response):
        """Yield text fragments from a synchronous OpenRouter SSE stream.

        Args:
            response: An object exposing ``iter_content(chunk_size=...)``, as
                a streamed ``requests`` response does.

        Yields:
            Each non-empty ``delta.content`` (or ``text``) fragment in order.
            Processing stops at the ``[DONE]`` marker. Errors are yielded as
            human-readable strings rather than raised.
        """
        import codecs

        try:
            log.info("Starting to process streaming response from OpenRouter")
            # Decode the raw bytes as UTF-8 ourselves: an incremental decoder
            # copes with multi-byte characters split across chunks and does
            # not depend on the charset guessed from the response headers.
            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
            buffer = ""

            for chunk in response.iter_content(chunk_size=1024):
                buffer += decoder.decode(chunk) if isinstance(chunk, bytes) else str(chunk)

                outputs, buffer, done = self._consume_sse_buffer(buffer)
                for text in outputs:
                    yield text
                if done:
                    return

            # Flush a final line that was not newline-terminated.
            buffer += decoder.decode(b"", final=True)
            outputs, _, _ = self._consume_sse_buffer(buffer, final=True)
            for text in outputs:
                yield text
        except Exception as e_stream:
            log.error(f"Error in streaming response: {str(e_stream)}")
            yield f"Error in streaming response: {str(e_stream)}"

    async def _process_async_streaming_response(self, response):
        """Yield text fragments from an asynchronous OpenRouter SSE stream.

        Args:
            response: An object whose ``content`` attribute can be iterated
                with ``async for`` and produces ``bytes`` (or text) pieces.

        Yields:
            Each non-empty ``delta.content`` (or ``text``) fragment in order.
            Processing stops at the ``[DONE]`` marker. Errors are yielded as
            human-readable strings rather than raised.
        """
        import codecs

        buffer = ""
        try:
            log.info("Starting to process async streaming response from OpenRouter")
            # Incremental decoding keeps split multi-byte characters intact.
            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

            async for chunk in response.content:
                buffer += decoder.decode(chunk) if isinstance(chunk, bytes) else str(chunk)

                outputs, buffer, done = self._consume_sse_buffer(buffer)
                for text in outputs:
                    yield text
                if done:
                    return

            # Flush a final line that was not newline-terminated.
            buffer += decoder.decode(b"", final=True)
            outputs, _, _ = self._consume_sse_buffer(buffer, final=True)
            for text in outputs:
                yield text
        except Exception as e_stream:
            log.error(f"Error in async streaming response: {str(e_stream)}")
            yield f"Error in streaming response: {str(e_stream)}"

    @classmethod
    def _consume_sse_buffer(cls, buffer: str, final: bool = False):
        """Parse the complete lines currently held in ``buffer``.

        Args:
            buffer: Text received so far that has not been parsed yet.
            final: When true, the trailing unterminated text is parsed as a
                line too (used once the stream has ended).

        Returns:
            ``(outputs, rest, done)``: the strings to yield, the unparsed
            remainder to carry into the next call, and whether the ``[DONE]``
            marker was seen.
        """
        *lines, rest = buffer.split("\n")
        if final:
            lines.append(rest)
            rest = ""

        outputs = []
        for line in lines:
            try:
                content, done = cls._parse_sse_line(line)
            except Exception as e_chunk:
                # One malformed event should not end the whole stream.
                log.error(f"Error processing streaming chunk: {str(e_chunk)}")
                outputs.append(f"Error processing response chunk: {str(e_chunk)}")
                continue
            if done:
                return outputs, "", True
            if content is not None:
                outputs.append(content)
        return outputs, rest, False

    @staticmethod
    def _parse_sse_line(line: str):
        """Interpret a single server-sent-event line.

        Returns:
            ``(content, done)``. ``content`` is the text carried by the line,
            or ``None`` for blank lines, comments, non-``data`` fields,
            unparsable JSON and events without text. ``done`` is true only for
            the ``[DONE]`` marker.
        """
        line = line.strip()
        if not line:
            return None, False

        log.debug(f"Processing line: {line}")

        # SSE comments (often keep-alives) start with ':'.
        if line.startswith(":"):
            log.debug(f"Skipping SSE comment: {line}")
            return None, False

        # The space after "data:" is optional in the SSE format.
        if not line.startswith("data:"):
            return None, False
        data = line[5:].strip()

        if data == "[DONE]":
            log.info("Received [DONE] marker")
            return None, True

        try:
            data_obj = json.loads(data)
        except json.JSONDecodeError:
            log.warning(f"Failed to parse SSE data: {data}")
            return None, False
        log.debug(f"Parsed JSON data: {data_obj}")

        choices = data_obj.get("choices") if isinstance(data_obj, dict) else None
        if not isinstance(choices, list) or not choices:
            log.debug(f"No choices found in data: {data_obj}")
            return None, False

        choice = choices[0]
        if isinstance(choice, dict):
            delta = choice.get("delta")
            if isinstance(delta, dict) and delta.get("content"):
                log.debug(f"Yielding delta content: {delta['content']}")
                return delta["content"], False
            if choice.get("text") is not None:
                log.debug(f"Yielding text content: {choice['text']}")
                return choice["text"], False

        log.debug(f"No content found in choice: {choice}")
        return None, False