import json
import os
import random
from typing import AsyncGenerator, Optional, List, Any

import httpx
from dotenv import load_dotenv

from ai_client import AIClient

load_dotenv()


class GitHubClient(AIClient):
    """Streaming chat-completions client for the GitHub Models endpoint.

    API keys are read from the environment once, at construction time. Each
    request picks one of the configured keys at random, and all requests
    share a single ``httpx.AsyncClient``.
    """

    # Environment variables that may each hold one API key
    # (supports GITHUB_API_KEY, GITHUB_API_KEY2, etc.).
    _API_KEY_ENV_VARS = (
        "GITHUB_API_KEY",
        "GITHUB_API_KEY2",
        "GITHUB_TOKEN_GROK",
        "GITHUB_TOKEN_PROMPT",
    )

    # Base URL for GitHub Models (the API is OpenAI-compatible).
    _CHAT_COMPLETIONS_URL = "https://models.inference.ai.azure.com/chat/completions"

    # Server-sent-events framing used by the streaming response.
    _SSE_DATA_PREFIX = "data:"
    _SSE_DONE = "[DONE]"

    def __init__(self):
        """Collect the configured API keys and create the shared HTTP client."""
        self.api_keys = []
        for env_name in self._API_KEY_ENV_VARS:
            raw_value = os.getenv(env_name)
            if not raw_value:
                continue
            stripped = raw_value.strip()
            # A whitespace-only variable is truthy but is not a usable key;
            # storing it would let random.choice() pick an empty token.
            if stripped:
                self.api_keys.append(stripped)
        self.client = httpx.AsyncClient(timeout=60.0)

    def _get_api_key(self) -> str:
        """Return a randomly chosen configured key, or "" when none is set."""
        if not self.api_keys:
            return ""
        return random.choice(self.api_keys)

    @staticmethod
    def _extract_content(data_str: str) -> str:
        """Return the text delta carried by one SSE JSON payload.

        Args:
            data_str: The text that followed the ``data:`` prefix of one
                event line, already stripped.

        Returns:
            The string at ``choices[0].delta.content``, or ``""`` when the
            payload is not valid JSON, does not have that shape, or carries
            no string content.
        """
        try:
            payload = json.loads(data_str)
        except ValueError:
            # json.JSONDecodeError subclasses ValueError; skip bad chunks.
            return ""
        if not isinstance(payload, dict):
            return ""
        choices = payload.get("choices")
        # Chunks may arrive with a missing, null, or empty "choices" list.
        if not isinstance(choices, list) or not choices:
            return ""
        first_choice = choices[0]
        if not isinstance(first_choice, dict):
            return ""
        delta = first_choice.get("delta")
        if not isinstance(delta, dict):
            return ""
        content = delta.get("content")
        return content if isinstance(content, str) else ""

    async def async_stream_request(
        self,
        model_id: str,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion, yielding text fragments as they arrive.

        Args:
            model_id: Value sent as the ``model`` field of the request.
            prompt: Content of the user message.
            system_prompt: Optional content of a leading system message.
            **kwargs: ``temperature`` (default 1.0) and ``max_tokens``
                (default 1000) are forwarded in the request body; other
                keys are ignored.

        Yields:
            Each non-empty content delta from the response stream. Failures
            are not raised: a non-200 response yields a single
            ``"Error from GitHub (...)"`` string and any exception yields a
            single ``"Connection error: ..."`` string.
        """
        api_key = self._get_api_key()
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        # GitHub Models API is OpenAI-compatible
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload = {
            "model": model_id,
            "messages": messages,
            "stream": True,
            "temperature": kwargs.get("temperature", 1.0),
            "top_p": 1.0,
            "max_tokens": kwargs.get("max_tokens", 1000),
        }

        prefix = self._SSE_DATA_PREFIX
        try:
            async with self.client.stream(
                "POST", self._CHAT_COMPLETIONS_URL, headers=headers, json=payload
            ) as response:
                if response.status_code != 200:
                    err_body = await response.aread()
                    # Decode leniently so a non-UTF-8 error body is still
                    # reported with its HTTP status instead of surfacing as
                    # a misleading "Connection error".
                    err_text = err_body.decode(errors="replace")
                    yield f"Error from GitHub ({response.status_code}): {err_text}"
                    return

                async for line in response.aiter_lines():
                    # Accept both "data: {...}" and "data:{...}"; the space
                    # after the colon is optional in the SSE format.
                    if not line.startswith(prefix):
                        continue
                    data_str = line[len(prefix):].strip()
                    if data_str == self._SSE_DONE:
                        break
                    content = self._extract_content(data_str)
                    if content:
                        yield content
        except Exception as e:
            # Best-effort boundary: report the failure in-band to the
            # consumer rather than raising out of the generator.
            yield f"Connection error: {str(e)}"

    async def close(self) -> None:
        """Close the shared HTTP client."""
        await self.client.aclose()