Update app.py
app.py
CHANGED
@@ -41,8 +41,13 @@ CONFIG = {
 # hosted model) which can generate longer outputs faster than a CPU-bound local
 # model. Set `HF_INFERENCE_MODEL` to choose the remote model (instruction-tuned
 # model recommended).
+#
+# PHI models are excellent lightweight instruction-following models:
+# - microsoft/phi-2 (2.7B parameters, free inference)
+# - microsoft/Phi-3-mini-4k-instruct (3.8B parameters, recommended)
+# - microsoft/Phi-3-mini-128k-instruct (3.8B with longer context)
 USE_REMOTE_LLM = False
-REMOTE_LLM_MODEL = os.environ.get("HF_INFERENCE_MODEL", "
+REMOTE_LLM_MODEL = os.environ.get("HF_INFERENCE_MODEL", "microsoft/Phi-3-mini-4k-instruct")

 # Prefer the environment variable, but also allow a local token file for users
 # who don't know how to set env vars. Create a file named `hf_token.txt` in the
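For reference, the configuration convention these comments describe can be reproduced outside the Space; the `load_hf_token` helper and the token-file location below are illustrative assumptions, not code from app.py:

```python
import os
from typing import Optional

# Hypothetical helper mirroring the described convention: prefer the
# HF_INFERENCE_API_KEY env var, fall back to a local hf_token.txt file.
def load_hf_token(path: str = "hf_token.txt") -> Optional[str]:
    token = os.environ.get("HF_INFERENCE_API_KEY")
    if token:
        return token.strip()
    try:
        with open(path, encoding="utf-8") as f:
            return f.read().strip()
    except OSError:
        return None

# Model selection works the same way as the REMOTE_LLM_MODEL line above.
model = os.environ.get("HF_INFERENCE_MODEL", "microsoft/Phi-3-mini-4k-instruct")
```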
@@ -69,13 +74,17 @@ def initialize_llm():
     # If a remote HF Inference API key is provided, we won't instantiate a local
     # heavy model; instead generation will be performed via the HTTP API.
     global USE_REMOTE_LLM, REMOTE_LLM_MODEL
+    # For Hugging Face Spaces deployment: prefer remote PHI inference
+    # This avoids memory issues on CPU-only spaces and provides better performance
     if USE_REMOTE_LLM:
-        logger.info(f"Using remote Hugging Face Inference model: {REMOTE_LLM_MODEL}")
+        logger.info(f"Using remote Hugging Face Inference with PHI model: {REMOTE_LLM_MODEL}")
+        logger.info("PHI models are optimized for instruction-following and long-form generation")
         CONFIG["llm_model"] = REMOTE_LLM_MODEL
-        CONFIG["model_type"] = "
+        CONFIG["model_type"] = "remote_phi"
         return None

-
+    # Final fallback: attempt to initialize the free local T5 model (as before)
+    logger.info("Initializing FREE local language model (fallback to T5)...")
     model_name = "google/flan-t5-large"

     try:
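The local fallback branch continues into a `try:` block that this hunk cuts off. With the standard transformers API, a minimal sketch of such a Flan-T5 fallback looks like this (app.py's exact arguments are not visible in the diff):

```python
from transformers import pipeline

def initialize_local_t5(model_name: str = "google/flan-t5-large"):
    # "text2text-generation" is the pipeline task for seq2seq models like Flan-T5;
    # device placement, dtype, and generation defaults are left to the library here.
    return pipeline("text2text-generation", model=model_name)

# Usage: generator = initialize_local_t5(); generator("Summarize: ...")
```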
@@ -106,58 +115,82 @@ def remote_generate(prompt: str, max_new_tokens: int = 512, temperature: float =
     """Call the Hugging Face Inference API for remote generation. Requires
     `HF_INFERENCE_API_KEY` env var to be set and a model name in
     `REMOTE_LLM_MODEL`.
+
+    PHI models work best with clear instruction formatting. This function
+    handles both the standard HF Inference API and PHI-specific response parsing.
     """
     if not HF_INFERENCE_API_KEY:
         raise Exception("HF_INFERENCE_API_KEY not set for remote generation")

-    #
-
-
-
+    # Use the HF Inference API endpoint (not router for better PHI compatibility)
+    api_url = f"https://api-inference.huggingface.co/models/{REMOTE_LLM_MODEL}"
+    headers = {"Authorization": f"Bearer {HF_INFERENCE_API_KEY}"}
+
+    # PHI models prefer simple parameters; avoid return_full_text which can cause issues
     payload = {
         "inputs": prompt,
         "parameters": {
             "max_new_tokens": max_new_tokens,
             "temperature": temperature,
             "top_p": top_p,
-            "
+            "do_sample": True,
+            "repetition_penalty": 1.1
         }
     }

-    logger.info(f"Remote inference
+    logger.info(f"Remote PHI inference to {REMOTE_LLM_MODEL} (tokens={max_new_tokens}, temp={temperature})")
     try:
-        r = requests.post(
+        r = requests.post(api_url, headers=headers, json=payload, timeout=90)
     except Exception as e:
-        logger.error(f"Remote
-
+        logger.error(f"Remote request failed: {e}")
+        return ""
+
+    if r.status_code == 503:
+        logger.warning("Model loading (503), retrying in 5s...")
+        import time
+        time.sleep(5)
         try:
-
-
-
-            logger.error(f"Legacy endpoint request failed: {e2}")
+            r = requests.post(api_url, headers=headers, json=payload, timeout=90)
+        except Exception as e:
+            logger.error(f"Retry failed: {e}")
             return ""

     if r.status_code != 200:
-        logger.error(f"Remote inference error {r.status_code}: {r.text[:
+        logger.error(f"Remote inference error {r.status_code}: {r.text[:300]}")
         return ""

     result = r.json()
+
+    # Handle error responses
     if isinstance(result, dict) and result.get("error"):
         logger.error(f"Remote inference returned error: {result.get('error')}")
         return ""

-    #
+    # Parse the generated text from various response formats
+    generated_text = ""
+
     if isinstance(result, list) and result:
-        #
+        # HF Inference API returns [{"generated_text": "..."}]
         first = result[0]
         if isinstance(first, dict):
-
-
-
-
-
+            generated_text = first.get("generated_text", "")
+        else:
+            generated_text = str(first)
+    elif isinstance(result, dict) and "generated_text" in result:
+        generated_text = result["generated_text"]
+    else:
+        generated_text = str(result)

-        return
+    # Clean up: PHI may return the prompt + completion, extract only new text
+    generated_text = generated_text.strip()
+
+    # If the response contains the original prompt, extract only the new completion
+    if prompt in generated_text:
+        # Find where the prompt ends and new generation begins
+        prompt_end = generated_text.find(prompt) + len(prompt)
+        generated_text = generated_text[prompt_end:].strip()
+
+    return generated_text

 def initialize_embeddings():
     logger.info("Initializing embeddings model...")
|