PeterPinetree commited on
Commit
8f0d448
·
1 Parent(s): f366b93

Switch to openai-community/gpt2 with gpt2-medium fallback for better serverless inference ability

Browse files
Files changed (1) hide show
  1. app.py +29 -9
app.py CHANGED
@@ -11,7 +11,7 @@ load_dotenv()
11
 
12
  # Configuration
13
  API_BASE = "https://api-inference.huggingface.co/models/"
14
- MODEL_ID = "gpt2"
15
  HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')
16
 
17
  def show_token(token: str) -> str:
@@ -53,15 +53,35 @@ def predict_next_token(text: str, top_k: int = 10, hide_punctuation: bool = Fals
53
 
54
  response = requests.post(url, headers=headers, json=payload, timeout=30)
55
 
 
 
 
56
  if not response.ok:
57
- error_msg = f"API Error: {response.status_code} for model {MODEL_ID}"
58
- try:
59
- error_detail = response.json()
60
- if 'error' in error_detail:
61
- error_msg += f" - {error_detail['error']}"
62
- except:
63
- error_msg += f" - {response.text[:200]}"
64
- return error_msg, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  result = response.json()
67
  prediction_time = int((time.time() - start_time) * 1000)
 
11
 
12
  # Configuration
13
  API_BASE = "https://api-inference.huggingface.co/models/"
14
+ MODEL_ID = "openai-community/gpt2"
15
  HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')
16
 
17
  def show_token(token: str) -> str:
 
53
 
54
  response = requests.post(url, headers=headers, json=payload, timeout=30)
55
 
56
+ # Debug logging
57
+ print(f"API URL: {url}")
58
+ print(f"Response status: {response.status_code}")
59
  if not response.ok:
60
+ print(f"Response text: {response.text}")
61
+
62
+ if not response.ok:
63
+ # Try GPT-2 Medium as fallback if the main model fails
64
+ if MODEL_ID == "openai-community/gpt2":
65
+ print(f"Main model failed, trying GPT-2 Medium fallback...")
66
+ fallback_url = f"{API_BASE}openai-community/gpt2-medium"
67
+ fallback_response = requests.post(fallback_url, headers=headers, json=payload, timeout=30)
68
+ print(f"Fallback response status: {fallback_response.status_code}")
69
+ if fallback_response.ok:
70
+ response = fallback_response
71
+ print("✅ Fallback successful!")
72
+ else:
73
+ print(f"Fallback also failed: {fallback_response.text[:100]}")
74
+
75
+ # If still not ok after fallback attempt
76
+ if not response.ok:
77
+ error_msg = f"API Error: {response.status_code} for model {MODEL_ID}"
78
+ try:
79
+ error_detail = response.json()
80
+ if 'error' in error_detail:
81
+ error_msg += f" - {error_detail['error']}"
82
+ except:
83
+ error_msg += f" - {response.text[:200]}"
84
+ return error_msg, ""
85
 
86
  result = response.json()
87
  prediction_time = int((time.time() - start_time) * 1000)