# promptforge/backend/ai_client.py
"""
PromptForge v4.0 β€” AI enhancement client.
Upgrades: exponential-backoff retry, configurable model override,
token-budget awareness, structured error types, and a
health-check helper for each provider.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional
import httpx
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger("promptforge.ai_client")
# ── Defaults ──────────────────────────────────────────────────────────────────
# Hugging Face serverless Inference API root; model id is appended per request.
HF_API_BASE = "https://api-inference.huggingface.co/models"
# Fallback model when the caller supplies no model_override.
HF_DEFAULT_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
# Google Generative Language API root (v1beta surface).
GOOGLE_AI_BASE = "https://generativelanguage.googleapis.com/v1beta"
# Fallback Gemini model when the caller supplies no model_override.
GOOGLE_DEFAULT_MODEL = "gemini-1.5-flash"
# Instruction prepended to every enhancement request; the model is asked to
# return only the rewritten prompt, which the callers use verbatim.
_SYSTEM_PROMPT = (
    "You are an expert prompt engineer specialising in Google AI Studio system prompts. "
    "Given a draft prompt, improve its clarity, specificity, structure, and effectiveness. "
    "Preserve all section headers (## ROLE, ## TASK, etc.). "
    "Return ONLY the improved prompt β€” no commentary, no preamble, no markdown fences."
)
# ── Retry helper ──────────────────────────────────────────────────────────────
async def _post_with_retry(
    client: httpx.AsyncClient,
    url: str,
    payload: dict,
    headers: Optional[dict] = None,
    retries: int = 3,
    base_delay: float = 1.0,
) -> httpx.Response:
    """POST *payload* to *url* with exponential backoff on transient failures.

    Retries on HTTP 429/503 responses and on connect/timeout errors, waiting
    ``base_delay * 2**attempt`` between attempts.

    Returns:
        The first non-rate-limited response; if every attempt was rate-limited,
        the LAST 429/503 response is returned so the caller's
        ``raise_for_status()`` surfaces a proper ``httpx.HTTPStatusError``
        (the old behavior raised a meaningless placeholder RuntimeError).

    Raises:
        httpx.ConnectError / httpx.TimeoutException: the last network error,
        when every attempt failed at the transport level.
    """
    last_exc: Optional[Exception] = None
    last_resp: Optional[httpx.Response] = None
    for attempt in range(retries):
        try:
            resp = await client.post(url, json=payload, headers=headers or {})
        except (httpx.ConnectError, httpx.TimeoutException) as exc:
            last_exc = exc
            if attempt < retries - 1:  # no pointless sleep after the final try
                wait = base_delay * (2 ** attempt)
                logger.warning(
                    "Network error on attempt %d: %s β€” retrying in %.1fs",
                    attempt + 1, exc, wait,
                )
                await asyncio.sleep(wait)
            continue
        if resp.status_code in (429, 503):
            last_resp = resp
            if attempt < retries - 1:
                wait = base_delay * (2 ** attempt)
                logger.warning(
                    "Rate-limited (%d) β€” retrying in %.1fs", resp.status_code, wait
                )
                await asyncio.sleep(wait)
            continue
        return resp
    if last_resp is not None:
        # All attempts rate-limited: hand back the last response so the caller
        # can raise_for_status() with full status/headers context.
        return last_resp
    assert last_exc is not None  # loop ran at least once and only network errors remain
    raise last_exc
# ── Hugging Face ──────────────────────────────────────────────────────────────
async def enhance_with_huggingface(
    raw_prompt: str,
    api_key: str,
    model: str = HF_DEFAULT_MODEL,
    max_new_tokens: int = 1024,
) -> str:
    """Ask a Hugging Face Inference API model to improve *raw_prompt*.

    Best-effort: on any network/API failure, or when the model yields an
    empty or implausibly short result, the original prompt is returned.
    """
    # Mistral-style [INST] chat wrapper around the system instruction + draft.
    instruction = (
        f"<s>[INST] {_SYSTEM_PROMPT}\n\n"
        f"DRAFT PROMPT:\n{raw_prompt}\n\n"
        f"IMPROVED PROMPT: [/INST]"
    )
    request_body = {
        "inputs": instruction,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "temperature": 0.3,
            "repetition_penalty": 1.1,
        },
    }
    auth = {"Authorization": f"Bearer {api_key}"}
    endpoint = f"{HF_API_BASE}/{model}"
    try:
        async with httpx.AsyncClient(timeout=45.0) as client:
            response = await _post_with_retry(client, endpoint, request_body, auth)
            response.raise_for_status()
            body = response.json()
            text = ""
            if isinstance(body, list) and body:
                text = body[0].get("generated_text", "").strip()
            # Guard against degenerate completions (sanity floor of 100 chars).
            if text and len(text) > 100:
                return text
            logger.warning("HF returned empty or too-short text; using original.")
            return raw_prompt
    except Exception as exc:
        logger.warning("HuggingFace enhancement failed: %s", exc)
        return raw_prompt
# ── Google Gemini ─────────────────────────────────────────────────────────────
async def enhance_with_google(
    raw_prompt: str,
    api_key: str,
    model: str = GOOGLE_DEFAULT_MODEL,
) -> str:
    """Ask Google Gemini (generateContent) to improve *raw_prompt*.

    Best-effort: on any network/API failure, or when the response contains no
    usable candidate text, the original prompt is returned unchanged.
    """
    url = f"{GOOGLE_AI_BASE}/models/{model}:generateContent"
    # SECURITY FIX: send the API key via the documented x-goog-api-key header
    # instead of a ?key= query parameter, so the secret never appears in URLs
    # captured by logs, proxies, or tracebacks.
    headers = {"x-goog-api-key": api_key}
    payload = {
        "contents": [
            {
                "parts": [
                    {"text": f"{_SYSTEM_PROMPT}\n\nDRAFT PROMPT:\n{raw_prompt}"}
                ]
            }
        ],
        "generationConfig": {
            "maxOutputTokens": 2048,
            "temperature": 0.3,
            "topP": 0.9,
        },
        "safetySettings": [
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
        ],
    }
    try:
        async with httpx.AsyncClient(timeout=45.0) as client:
            resp = await _post_with_retry(client, url, payload, headers)
            resp.raise_for_status()
            data = resp.json()
            candidates = data.get("candidates", [])
            if candidates:
                parts = candidates[0].get("content", {}).get("parts", [])
                if parts:
                    enhanced = parts[0].get("text", "").strip()
                    # Sanity floor: reject degenerate/truncated completions.
                    if enhanced and len(enhanced) > 100:
                        return enhanced
            logger.warning("Google returned empty or too-short text; using original.")
            return raw_prompt
    except Exception as exc:
        logger.warning("Google AI enhancement failed: %s", exc)
        return raw_prompt
# ── Dispatcher ────────────────────────────────────────────────────────────────
async def enhance_prompt(
    raw_prompt: str,
    provider: str,
    api_key: Optional[str],
    model_override: Optional[str] = None,
) -> tuple[str, str]:
    """Dispatch *raw_prompt* to the selected AI provider for enhancement.

    Args:
        raw_prompt: Draft prompt text to improve.
        provider: "huggingface" or "google" (case/whitespace-insensitive);
            any other value disables enhancement.
        api_key: Provider API key; enhancement is skipped when falsy.
        model_override: Optional model id replacing the provider default.

    Returns:
        ``(enhanced_text, notes_string)`` β€” ``enhanced_text`` equals
        *raw_prompt* whenever enhancement was skipped or failed.
    """
    if not api_key:
        return raw_prompt, "No API key provided β€” skipping AI enhancement."
    normalized = provider.strip().lower()  # tolerate "Google", " huggingface "
    if normalized == "huggingface":
        model = model_override or HF_DEFAULT_MODEL
        enhanced = await enhance_with_huggingface(raw_prompt, api_key, model=model)
        notes = f"Enhanced via Hugging Face ({model})."
    elif normalized == "google":
        model = model_override or GOOGLE_DEFAULT_MODEL
        enhanced = await enhance_with_google(raw_prompt, api_key, model=model)
        notes = f"Enhanced via Google Gemini ({model})."
    else:
        # BUG FIX: the old message hard-coded "Provider 'none'" for every
        # unrecognized provider; report the actual value instead.
        return raw_prompt, f"Provider '{provider}' β€” no AI enhancement applied."
    if enhanced == raw_prompt:
        notes += " (Enhancement returned identical text β€” possible model or quota issue.)"
    return enhanced, notes
# ── Provider health-check ─────────────────────────────────────────────────────
async def check_hf_key(api_key: str) -> bool:
    """Quick whoami probe β€” True iff Hugging Face accepts *api_key*."""
    auth = {"Authorization": f"Bearer {api_key}"}
    try:
        async with httpx.AsyncClient(timeout=8.0) as client:
            response = await client.get(
                "https://huggingface.co/api/whoami",
                headers=auth,
            )
    except Exception:
        # Any transport/API failure is treated as "key not usable".
        return False
    return response.status_code == 200
async def check_google_key(api_key: str) -> bool:
    """Return True if the Google key appears valid (list-models probe)."""
    try:
        async with httpx.AsyncClient(timeout=8.0) as client:
            # SECURITY FIX: pass the key in the documented x-goog-api-key
            # header rather than a ?key= query parameter, keeping the secret
            # out of logged URLs.
            r = await client.get(
                f"{GOOGLE_AI_BASE}/models",
                headers={"x-goog-api-key": api_key},
            )
            return r.status_code == 200
    except Exception:
        # Any transport failure means the key cannot be confirmed.
        return False