Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import requests | |
| import json | |
| import codecs | |
| from pathlib import Path | |
# RunPod connection settings, read from the environment once at import time.
API_KEY = os.environ.get("RUNPOD_API_KEY")
ENDPOINT = os.environ.get("RUNPOD_ENDPOINT")

# Headers attached to every RunPod HTTP request.
# NOTE: if RUNPOD_API_KEY is unset, API_KEY is None and the Authorization
# value becomes the literal string "Bearer None".
HEADERS = {
    "Authorization": "Bearer " + str(API_KEY),
    "Content-Type": "application/json",
}
def format_messages_as_prompt(messages):
    """Flatten a list of chat messages into a single prompt string.

    Each message (a mapping with ``"role"`` and ``"content"`` keys) is
    rendered as ``"<Role>: <content>"`` with the role capitalized. A final
    ``"Assistant:"`` line is appended, and all pieces are separated by a
    blank line.
    """
    turns = [f"{msg['role'].capitalize()}: {msg['content']}" for msg in messages]
    return "\n\n".join([*turns, "Assistant:"])
def run_prompt(
    prompt: str,
    max_tokens: int = 512,
    context_length: int = 8192,
    poll_interval: float = 3.0,
    request_timeout: float = 60.0,
    max_wait: float = float("inf"),
) -> str:
    """Submit a prompt to the RunPod endpoint and return the response string.

    The ``/run`` route is called first. If its JSON reply carries a
    ``success`` key it is treated as a synchronous handler result; if it
    carries an ``id`` key, ``/status/<id>`` is polled until the job reaches
    a terminal state.

    Args:
        prompt: Prompt text, sent as ``input.prompt``.
        max_tokens: Generation limit, sent both as ``input.max_tokens`` and
            ``input.options.num_predict``.
        context_length: Sent as ``input.options.num_ctx``.
        poll_interval: Seconds to sleep between status checks.
        request_timeout: Timeout in seconds for each HTTP request (the
            submit call and every status call).
        max_wait: Maximum seconds to keep polling an async job. The default
            (infinity) means no overall limit.

    Returns:
        The ``response`` value produced by the handler.

    Raises:
        RuntimeError: On a non-200 HTTP status, a handler/job failure, an
            unexpected reply shape, an unknown job status, or when
            ``max_wait`` elapses before the job finishes.
    """
    payload = {
        "input": {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "options": {
                "num_ctx": context_length,
                "num_predict": max_tokens,
            },
        }
    }
    response = requests.post(
        f"{ENDPOINT}/run", headers=HEADERS, json=payload, timeout=request_timeout
    )
    if response.status_code != 200:
        raise RuntimeError(f"HTTP {response.status_code}: {response.text}")
    result = response.json()
    print(f"Initial response: {result}")

    # A reply that already carries "success" is a direct (sync) handler result.
    if "success" in result:
        return _unwrap_handler_output(result, "Handler error")
    # A reply carrying a job id is a queued async job that must be polled.
    if "id" in result:
        return _poll_job(result["id"], poll_interval, request_timeout, max_wait)
    raise RuntimeError(f"Unexpected response format: {result}")


def _unwrap_handler_output(output: dict, failure_prefix: str) -> str:
    """Return ``output["response"]`` if ``output["success"]`` is truthy, else raise RuntimeError."""
    if output.get("success"):
        return output["response"]
    raise RuntimeError(f"{failure_prefix}: {output.get('error', 'Unknown error')}")


def _poll_job(job_id, poll_interval: float, request_timeout: float, max_wait: float) -> str:
    """Poll ``/status/<job_id>`` until the job finishes and return its response text."""
    print(f"[RunPod] Job started: {job_id}")
    status_url = f"{ENDPOINT}/status/{job_id}"  # invariant across polls
    deadline = time.monotonic() + max_wait  # stays infinite when max_wait is inf
    while True:
        # A timeout here prevents a stalled connection from hanging forever.
        status_response = requests.get(status_url, headers=HEADERS, timeout=request_timeout)
        if status_response.status_code != 200:
            raise RuntimeError(f"Status check failed: {status_response.status_code}")
        status_result = status_response.json()
        status = status_result.get("status")
        print(f"[RunPod] Status: {status}")

        if status == "COMPLETED":
            # "output" may be missing or explicitly null; treat both as empty.
            output = status_result.get("output") or {}
            if not isinstance(output, dict):
                raise RuntimeError(f"Job failed: unexpected output {output!r}")
            return _unwrap_handler_output(output, "Job failed")
        if status == "FAILED":
            error_msg = status_result.get("error", "Job failed without specific error")
            raise RuntimeError(f"RunPod job failed: {error_msg}")
        # RunPod's serverless docs also list these as terminal states.
        if status in ("CANCELLED", "TIMED_OUT"):
            raise RuntimeError(f"RunPod job {job_id} ended with status {status}")
        if status not in ("IN_QUEUE", "IN_PROGRESS"):
            raise RuntimeError(f"Unknown job status: {status}")

        if time.monotonic() >= deadline:
            raise RuntimeError(
                f"RunPod job {job_id} did not finish within {max_wait} seconds "
                f"(last status: {status})"
            )
        print(f"[RunPod] Waiting... Status: {status}")
        time.sleep(poll_interval)
def _strip_json_wrappers(raw_text: str) -> str:
    """Remove surrounding whitespace, Markdown backtick fences, a ``json`` fence label, and outer single quotes."""
    text = raw_text.strip().lstrip("`")
    # A valid JSON document never starts with "json", so dropping that label
    # (any case) cannot eat real data. str.strip("```json") would instead strip
    # a *set* of characters and mangle inputs such as "null" -> "ull".
    if text[:4].lower() == "json":
        text = text[4:]
    return text.rstrip("`").strip().strip("'")


def clean_and_parse_json(raw_text: str):
    """Clean model output and parse it as JSON.

    Fences such as ```json ... ``` (label matched case-insensitively), plain
    backtick fences, and surrounding single quotes are removed first. If the
    result is not valid JSON, backslash escapes (e.g. ``\\"``) are decoded
    and parsing is retried once.

    Args:
        raw_text: Raw text returned by the model.

    Returns:
        The parsed JSON value (dict, list, str, number, bool, or None).

    Raises:
        ValueError: If the text cannot be parsed, even after unescaping.
    """
    cleaned = _strip_json_wrappers(raw_text)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    try:
        # Undo backslash escapes such as \" . Encoding via latin-1 with
        # backslashreplace keeps non-ASCII characters intact; passing the str
        # straight to codecs.decode(..., "unicode_escape") re-reads its UTF-8
        # bytes as latin-1 and turns "é" into "Ã©".
        unescaped = cleaned.encode("latin-1", "backslashreplace").decode("unicode_escape")
        return json.loads(unescaped)
    except (ValueError, RecursionError) as e:
        # ValueError covers JSONDecodeError and UnicodeDecodeError; RecursionError
        # comes from pathologically deep nesting in json.loads.
        raise ValueError("Could not parse JSON output") from e