test2text / app /backend /runpod_client.py
davidepanza's picture
Update app/backend/runpod_client.py
82faabd verified
import os
import time
import requests
import json
import codecs
from pathlib import Path
# RunPod credentials and endpoint base URL, read from the environment once at
# import time. Either value is None when its variable is not set.
API_KEY = os.environ.get("RUNPOD_API_KEY")
ENDPOINT = os.environ.get("RUNPOD_ENDPOINT")

# Headers attached to every request made by this module.
# NOTE: if RUNPOD_API_KEY is unset, the Authorization value is the literal
# string "Bearer None"; likewise request URLs built from an unset ENDPOINT
# start with "None/".
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}
def format_messages_as_prompt(messages):
    """Flatten a list of chat messages into one prompt string.

    Each message is a mapping with ``'role'`` and ``'content'`` keys and is
    rendered as ``"<Role>: <content>"`` with the role capitalized. The turns
    are separated by blank lines and followed by a trailing ``"Assistant:"``
    cue for the model to continue from.
    """
    turns = [
        f"{entry['role'].capitalize()}: {entry['content']}"
        for entry in messages
    ]
    # The final cue is added even when there are no messages at all.
    return "\n\n".join([*turns, "Assistant:"])
def run_prompt(prompt: str, max_tokens: int = 512, context_length: int = 8192) -> str:
    """Submit a prompt to the RunPod endpoint and get back a response string.

    The prompt is POSTed to ``{ENDPOINT}/run``. Two reply shapes are handled:

    * a direct (synchronous) reply containing a ``"success"`` key, whose
      ``"response"`` value is returned immediately;
    * an asynchronous reply containing a job ``"id"``, in which case
      ``{ENDPOINT}/status/{id}`` is polled every 3 seconds until the job
      reports ``COMPLETED`` or ``FAILED``.

    Args:
        prompt: Prompt text sent to the model.
        max_tokens: Sent both as ``max_tokens`` and as ``options.num_predict``.
        context_length: Sent as ``options.num_ctx``.

    Returns:
        The ``"response"`` value of the successful reply / job output.

    Raises:
        RuntimeError: On a non-200 HTTP status, a handler-reported failure,
            a failed job, an unrecognised job status, or a reply/output
            whose shape is not understood.
        requests.exceptions.RequestException: Propagated from ``requests``
            (connection errors, or a single request exceeding 60 seconds).

    Note:
        There is no overall deadline: polling continues for as long as the
        job stays ``IN_QUEUE`` / ``IN_PROGRESS``.
    """
    payload = {
        "input": {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "options": {
                "num_ctx": context_length,
                "num_predict": max_tokens
            }
        }
    }
    url = f"{ENDPOINT}/run"
    response = requests.post(url, headers=HEADERS, json=payload, timeout=60)
    if response.status_code != 200:
        raise RuntimeError(f"HTTP {response.status_code}: {response.text}")
    result = response.json()
    print(f"Initial response: {result}")

    # Check if it's a direct sync response
    if "success" in result:
        if result.get("success"):
            return result["response"]
        raise RuntimeError(f"Handler error: {result.get('error', 'Unknown error')}")

    # Anything that is neither a sync reply nor an async job handle is unusable.
    if "id" not in result:
        raise RuntimeError(f"Unexpected response format: {result}")

    # Handle async job response
    job_id = result["id"]
    print(f"[RunPod] Job started: {job_id}")
    status_url = f"{ENDPOINT}/status/{job_id}"  # invariant across polls

    # Poll for completion
    while True:
        # Same per-request timeout as the initial POST, so a stalled
        # connection cannot block the polling loop forever.
        status_response = requests.get(status_url, headers=HEADERS, timeout=60)
        if status_response.status_code != 200:
            raise RuntimeError(f"Status check failed: {status_response.status_code}")
        status_result = status_response.json()
        status = status_result.get("status")
        print(f"[RunPod] Status: {status}")

        if status == "COMPLETED":
            # A missing or null output is treated as an empty one, so it is
            # reported below as a failed job with an unknown error.
            output = status_result.get("output") or {}
            if not isinstance(output, dict):
                # Fail with the function's usual exception type instead of
                # an AttributeError from calling .get() on a str/list.
                raise RuntimeError(f"Unexpected job output format: {output!r}")
            if output.get("success"):
                return output["response"]
            raise RuntimeError(f"Job failed: {output.get('error', 'Unknown error')}")

        if status == "FAILED":
            error_msg = status_result.get("error", "Job failed without specific error")
            raise RuntimeError(f"RunPod job failed: {error_msg}")

        if status not in ("IN_QUEUE", "IN_PROGRESS"):
            raise RuntimeError(f"Unknown job status: {status}")

        print(f"[RunPod] Waiting... Status: {status}")
        time.sleep(3)  # Wait 3 seconds before polling again
def clean_and_parse_json(raw_text: str):
    """Clean and parse model output into JSON.

    Surrounding whitespace, Markdown code fences (with an optional ``json``
    language tag) and wrapping single quotes are removed before parsing. If
    the cleaned text is still not valid JSON, backslash escapes are decoded
    once (e.g. ``\\"`` becomes ``"``) and parsing is retried.

    Args:
        raw_text: Raw text produced by the model.

    Returns:
        The decoded JSON value (dict, list, str, number, bool or None).

    Raises:
        ValueError: If the text cannot be parsed even after unescaping.
    """
    # Backticks and single quotes can never begin or end a JSON document, so
    # stripping them as character sets is safe. The "json" tag, however, must
    # be removed as a *prefix*: str.strip("```json") would treat it as the
    # character set {`, j, s, o, n} and eat into payloads such as `null`.
    cleaned = raw_text.strip().strip("'").strip()
    cleaned = cleaned.strip("`").strip()
    if cleaned.startswith("json"):
        cleaned = cleaned[len("json"):]
    cleaned = cleaned.strip().strip("'").strip()

    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        try:
            # Handle escaped quotes. Encoding via latin-1 with
            # backslashreplace first keeps non-ASCII characters intact;
            # decoding the str directly with unicode_escape would turn them
            # into mojibake.
            escaped_bytes = cleaned.encode("latin-1", "backslashreplace")
            unescaped = codecs.decode(escaped_bytes, "unicode_escape")
            return json.loads(unescaped)
        except Exception as e:
            raise ValueError("Could not parse JSON output") from e