# NOTE(review): removed paste artifacts ("Spaces:", "Runtime error" x2) — not part of the module.
| """ | |
| PromptForge v4.0 β AI enhancement client. | |
| Upgrades: exponential-backoff retry, configurable model override, | |
| token-budget awareness, structured error types, and a | |
| health-check helper for each provider. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| from typing import Optional | |
| import httpx | |
| logger = logging.getLogger("promptforge.ai_client") | |
| # ββ Defaults ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| HF_API_BASE = "https://api-inference.huggingface.co/models" | |
| HF_DEFAULT_MODEL = "mistralai/Mistral-7B-Instruct-v0.2" | |
| GOOGLE_AI_BASE = "https://generativelanguage.googleapis.com/v1beta" | |
| GOOGLE_DEFAULT_MODEL = "gemini-1.5-flash" | |
| _SYSTEM_PROMPT = ( | |
| "You are an expert prompt engineer specialising in Google AI Studio system prompts. " | |
| "Given a draft prompt, improve its clarity, specificity, structure, and effectiveness. " | |
| "Preserve all section headers (## ROLE, ## TASK, etc.). " | |
| "Return ONLY the improved prompt β no commentary, no preamble, no markdown fences." | |
| ) | |
| # ββ Retry helper ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def _post_with_retry( | |
| client: httpx.AsyncClient, | |
| url: str, | |
| payload: dict, | |
| headers: Optional[dict] = None, | |
| retries: int = 3, | |
| base_delay: float = 1.0, | |
| ) -> httpx.Response: | |
| """POST with exponential backoff on 429 / 503 responses.""" | |
| last_exc: Exception = RuntimeError("No attempts made") | |
| for attempt in range(retries): | |
| try: | |
| resp = await client.post(url, json=payload, headers=headers or {}) | |
| if resp.status_code in (429, 503): | |
| wait = base_delay * (2 ** attempt) | |
| logger.warning("Rate-limited (%d) β retrying in %.1fs", resp.status_code, wait) | |
| await asyncio.sleep(wait) | |
| continue | |
| return resp | |
| except (httpx.ConnectError, httpx.TimeoutException) as exc: | |
| last_exc = exc | |
| wait = base_delay * (2 ** attempt) | |
| logger.warning("Network error on attempt %d: %s β retrying in %.1fs", attempt + 1, exc, wait) | |
| await asyncio.sleep(wait) | |
| raise last_exc | |
# ── Hugging Face ──────────────────────────────────────────────────────────────
async def enhance_with_huggingface(
    raw_prompt: str,
    api_key: str,
    model: str = HF_DEFAULT_MODEL,
    max_new_tokens: int = 1024,
) -> str:
    """Ask a Hugging Face instruct model to improve *raw_prompt*.

    Best-effort: on any failure (network, HTTP error, empty/short output)
    the original prompt is returned unchanged and a warning is logged.

    Args:
        raw_prompt: The draft prompt to enhance.
        api_key: Hugging Face API token.
        model: Model repo id on the Inference API.
        max_new_tokens: Generation budget for the completion.

    Returns:
        The enhanced prompt, or *raw_prompt* on failure.
    """
    instruction = (
        f"<s>[INST] {_SYSTEM_PROMPT}\n\n"
        f"DRAFT PROMPT:\n{raw_prompt}\n\n"
        f"IMPROVED PROMPT: [/INST]"
    )
    request_body = {
        "inputs": instruction,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "temperature": 0.3,
            "repetition_penalty": 1.1,
        },
    }
    auth_headers = {"Authorization": f"Bearer {api_key}"}
    endpoint = f"{HF_API_BASE}/{model}"
    try:
        async with httpx.AsyncClient(timeout=45.0) as client:
            response = await _post_with_retry(client, endpoint, request_body, auth_headers)
            response.raise_for_status()
            body = response.json()
            # The Inference API returns a list of generation dicts on success.
            if isinstance(body, list) and body:
                candidate = body[0].get("generated_text", "").strip()
                # Guard against degenerate one-liner outputs.
                if candidate and len(candidate) > 100:
                    return candidate
            logger.warning("HF returned empty or too-short text; using original.")
            return raw_prompt
    except Exception as exc:
        # Deliberate best-effort boundary: never let enhancement break the caller.
        logger.warning("HuggingFace enhancement failed: %s", exc)
        return raw_prompt
# ── Google Gemini ─────────────────────────────────────────────────────────────
async def enhance_with_google(
    raw_prompt: str,
    api_key: str,
    model: str = GOOGLE_DEFAULT_MODEL,
) -> str:
    """Ask a Gemini model to improve *raw_prompt*.

    Best-effort: on any failure (network, HTTP error, empty/short output)
    the original prompt is returned unchanged and a warning is logged.

    Fixes vs. the previous version:
      * The API key is sent in the ``x-goog-api-key`` header instead of the
        URL query string, so it no longer leaks into URL logs or exception
        messages.
      * All response ``parts`` are concatenated — Gemini may split a long
        completion across several parts, and only the first was kept before.

    Args:
        raw_prompt: The draft prompt to enhance.
        api_key: Google AI Studio API key.
        model: Gemini model name.

    Returns:
        The enhanced prompt, or *raw_prompt* on failure.
    """
    url = f"{GOOGLE_AI_BASE}/models/{model}:generateContent"
    headers = {"x-goog-api-key": api_key}
    payload = {
        "contents": [
            {
                "parts": [
                    {"text": f"{_SYSTEM_PROMPT}\n\nDRAFT PROMPT:\n{raw_prompt}"}
                ]
            }
        ],
        "generationConfig": {
            "maxOutputTokens": 2048,
            "temperature": 0.3,
            "topP": 0.9,
        },
        "safetySettings": [
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
        ],
    }
    try:
        async with httpx.AsyncClient(timeout=45.0) as client:
            resp = await _post_with_retry(client, url, payload, headers)
            resp.raise_for_status()
            data = resp.json()
            candidates = data.get("candidates", [])
            if candidates:
                parts = candidates[0].get("content", {}).get("parts", [])
                # Join every part; long outputs can arrive split.
                enhanced = "".join(p.get("text", "") for p in parts).strip()
                # Guard against degenerate one-liner outputs.
                if enhanced and len(enhanced) > 100:
                    return enhanced
            logger.warning("Google returned empty or too-short text; using original.")
            return raw_prompt
    except Exception as exc:
        # Deliberate best-effort boundary: never let enhancement break the caller.
        logger.warning("Google AI enhancement failed: %s", exc)
        return raw_prompt
# ── Dispatcher ────────────────────────────────────────────────────────────────
async def enhance_prompt(
    raw_prompt: str,
    provider: str,
    api_key: Optional[str],
    model_override: Optional[str] = None,
) -> tuple[str, str]:
    """Dispatch *raw_prompt* to the configured enhancement provider.

    Fixes vs. the previous version: the fallback note no longer claims the
    provider was 'none' for *any* unrecognised value — unknown providers are
    reported by name.  Provider matching is now case/whitespace-insensitive
    (backward-compatible: exact lowercase values behave as before).

    Args:
        raw_prompt: The draft prompt to enhance.
        provider: "huggingface", "google", or "none"/anything else to skip.
        api_key: Provider API key; when falsy, enhancement is skipped.
        model_override: Optional model name overriding the provider default.

    Returns:
        Tuple of (enhanced_text, human-readable notes).
    """
    if not api_key:
        return raw_prompt, "No API key provided — skipping AI enhancement."
    normalized = (provider or "").strip().lower()
    if normalized == "huggingface":
        model = model_override or HF_DEFAULT_MODEL
        enhanced = await enhance_with_huggingface(raw_prompt, api_key, model=model)
        notes = f"Enhanced via Hugging Face ({model})."
    elif normalized == "google":
        model = model_override or GOOGLE_DEFAULT_MODEL
        enhanced = await enhance_with_google(raw_prompt, api_key, model=model)
        notes = f"Enhanced via Google Gemini ({model})."
    elif normalized in ("", "none"):
        return raw_prompt, "Provider 'none' — no AI enhancement applied."
    else:
        # Previously this branch mislabelled every unknown provider as 'none'.
        return raw_prompt, f"Provider {provider!r} not recognised — no AI enhancement applied."
    if enhanced == raw_prompt:
        # Providers fall back to the original text on failure, so identical
        # output usually means the call silently failed or was filtered.
        notes += " (Enhancement returned identical text — possible model or quota issue.)"
    return enhanced, notes
# ── Provider health-check ─────────────────────────────────────────────────────
async def check_hf_key(api_key: str) -> bool:
    """Return True if the HF key appears valid (quick whoami probe)."""
    auth = {"Authorization": f"Bearer {api_key}"}
    try:
        async with httpx.AsyncClient(timeout=8.0) as client:
            response = await client.get(
                "https://huggingface.co/api/whoami",
                headers=auth,
            )
    except Exception:
        # Any network/protocol failure counts as "not valid".
        return False
    return response.status_code == 200
async def check_google_key(api_key: str) -> bool:
    """Return True if the Google key appears valid (list-models probe).

    Fix vs. the previous version: the key is sent in the ``x-goog-api-key``
    header instead of the URL query string, so it is not recorded in URL
    logs or surfaced in exception messages.
    """
    try:
        async with httpx.AsyncClient(timeout=8.0) as client:
            r = await client.get(
                f"{GOOGLE_AI_BASE}/models",
                headers={"x-goog-api-key": api_key},
            )
            return r.status_code == 200
    except Exception:
        # Any network/protocol failure counts as "not valid".
        return False