Commit
·
8f0d448
1
Parent(s):
f366b93
Switch to openai-community/gpt2 with gpt2-medium fallback for better serverless inference ability
Browse files
app.py
CHANGED
|
@@ -11,7 +11,7 @@ load_dotenv()
|
|
| 11 |
|
| 12 |
# Configuration
|
| 13 |
API_BASE = "https://api-inference.huggingface.co/models/"
|
| 14 |
-
MODEL_ID = "gpt2"
|
| 15 |
HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')
|
| 16 |
|
| 17 |
def show_token(token: str) -> str:
|
|
@@ -53,15 +53,35 @@ def predict_next_token(text: str, top_k: int = 10, hide_punctuation: bool = Fals
|
|
| 53 |
|
| 54 |
response = requests.post(url, headers=headers, json=payload, timeout=30)
|
| 55 |
|
|
|
|
|
|
|
|
|
|
| 56 |
if not response.ok:
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
result = response.json()
|
| 67 |
prediction_time = int((time.time() - start_time) * 1000)
|
|
|
|
| 11 |
|
| 12 |
# Configuration
|
| 13 |
API_BASE = "https://api-inference.huggingface.co/models/"
|
| 14 |
+
MODEL_ID = "openai-community/gpt2"
|
| 15 |
HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')
|
| 16 |
|
| 17 |
def show_token(token: str) -> str:
|
|
|
|
| 53 |
|
| 54 |
response = requests.post(url, headers=headers, json=payload, timeout=30)
|
| 55 |
|
| 56 |
+
# Debug logging
|
| 57 |
+
print(f"API URL: {url}")
|
| 58 |
+
print(f"Response status: {response.status_code}")
|
| 59 |
if not response.ok:
|
| 60 |
+
print(f"Response text: {response.text}")
|
| 61 |
+
|
| 62 |
+
if not response.ok:
|
| 63 |
+
# Try GPT-2 Medium as fallback if the main model fails
|
| 64 |
+
if MODEL_ID == "openai-community/gpt2":
|
| 65 |
+
print(f"Main model failed, trying GPT-2 Medium fallback...")
|
| 66 |
+
fallback_url = f"{API_BASE}openai-community/gpt2-medium"
|
| 67 |
+
fallback_response = requests.post(fallback_url, headers=headers, json=payload, timeout=30)
|
| 68 |
+
print(f"Fallback response status: {fallback_response.status_code}")
|
| 69 |
+
if fallback_response.ok:
|
| 70 |
+
response = fallback_response
|
| 71 |
+
print("✅ Fallback successful!")
|
| 72 |
+
else:
|
| 73 |
+
print(f"Fallback also failed: {fallback_response.text[:100]}")
|
| 74 |
+
|
| 75 |
+
# If still not ok after fallback attempt
|
| 76 |
+
if not response.ok:
|
| 77 |
+
error_msg = f"API Error: {response.status_code} for model {MODEL_ID}"
|
| 78 |
+
try:
|
| 79 |
+
error_detail = response.json()
|
| 80 |
+
if 'error' in error_detail:
|
| 81 |
+
error_msg += f" - {error_detail['error']}"
|
| 82 |
+
except:
|
| 83 |
+
error_msg += f" - {response.text[:200]}"
|
| 84 |
+
return error_msg, ""
|
| 85 |
|
| 86 |
result = response.json()
|
| 87 |
prediction_time = int((time.time() - start_time) * 1000)
|