hamxaameer committed on
Commit 3a9a518 · verified · 1 Parent(s): 7ba258a

Update app.py

Files changed (1)
  1. app.py +59 -26
app.py CHANGED
@@ -41,8 +41,13 @@ CONFIG = {
 # hosted model) which can generate longer outputs faster than a CPU-bound local
 # model. Set `HF_INFERENCE_MODEL` to choose the remote model (instruction-tuned
 # model recommended).
+#
+# PHI models are excellent lightweight instruction-following models:
+#   - microsoft/phi-2 (2.7B parameters, free inference)
+#   - microsoft/Phi-3-mini-4k-instruct (3.8B parameters, recommended)
+#   - microsoft/Phi-3-mini-128k-instruct (3.8B with longer context)
 USE_REMOTE_LLM = False
-REMOTE_LLM_MODEL = os.environ.get("HF_INFERENCE_MODEL", "tiiuae/falcon-7b-instruct")
+REMOTE_LLM_MODEL = os.environ.get("HF_INFERENCE_MODEL", "microsoft/Phi-3-mini-4k-instruct")
 
 # Prefer the environment variable, but also allow a local token file for users
 # who don't know how to set env vars. Create a file named `hf_token.txt` in the
@@ -69,13 +74,17 @@ def initialize_llm():
     # If a remote HF Inference API key is provided, we won't instantiate a local
     # heavy model; instead generation will be performed via the HTTP API.
     global USE_REMOTE_LLM, REMOTE_LLM_MODEL
+    # For Hugging Face Spaces deployment: prefer remote PHI inference
+    # This avoids memory issues on CPU-only spaces and provides better performance
     if USE_REMOTE_LLM:
-        logger.info(f"🔄 Using remote Hugging Face Inference model: {REMOTE_LLM_MODEL}")
+        logger.info(f"🔄 Using remote Hugging Face Inference with PHI model: {REMOTE_LLM_MODEL}")
+        logger.info(f" ✅ PHI models are optimized for instruction-following and long-form generation")
         CONFIG["llm_model"] = REMOTE_LLM_MODEL
-        CONFIG["model_type"] = "remote"
+        CONFIG["model_type"] = "remote_phi"
         return None
 
-    logger.info("🔄 Initializing FREE local language model...")
+    # Final fallback: attempt to initialize the free local T5 model (as before)
+    logger.info("🔄 Initializing FREE local language model (fallback to T5)...")
     model_name = "google/flan-t5-large"
 
     try:
@@ -106,58 +115,82 @@ def remote_generate(prompt: str, max_new_tokens: int = 512, temperature: float =
     """Call the Hugging Face Inference API for remote generation. Requires
     `HF_INFERENCE_API_KEY` env var to be set and a model name in
     `REMOTE_LLM_MODEL`.
+
+    PHI models work best with clear instruction formatting. This function
+    handles both the standard HF Inference API and PHI-specific response parsing.
     """
     if not HF_INFERENCE_API_KEY:
         raise Exception("HF_INFERENCE_API_KEY not set for remote generation")
 
-    # New router endpoint is required by HF (replaces api-inference.huggingface.co)
-    router_url = f"https://router.huggingface.co/models/{REMOTE_LLM_MODEL}"
-    old_url = f"https://api-inference.huggingface.co/models/{REMOTE_LLM_MODEL}"
-    headers = {"Authorization": f"Bearer {HF_INFERENCE_API_KEY}", "Accept": "application/json"}
+    # Use the HF Inference API endpoint (not router for better PHI compatibility)
+    api_url = f"https://api-inference.huggingface.co/models/{REMOTE_LLM_MODEL}"
+    headers = {"Authorization": f"Bearer {HF_INFERENCE_API_KEY}"}
+
+    # PHI models prefer simple parameters; avoid return_full_text which can cause issues
     payload = {
         "inputs": prompt,
         "parameters": {
             "max_new_tokens": max_new_tokens,
             "temperature": temperature,
             "top_p": top_p,
-            "return_full_text": False
+            "do_sample": True,
+            "repetition_penalty": 1.1
         }
     }
 
-    logger.info(f" → Remote inference request to router {REMOTE_LLM_MODEL} (tokens={max_new_tokens}, temp={temperature})")
+    logger.info(f" → Remote PHI inference to {REMOTE_LLM_MODEL} (tokens={max_new_tokens}, temp={temperature})")
     try:
-        r = requests.post(router_url, headers=headers, json=payload, timeout=120)
+        r = requests.post(api_url, headers=headers, json=payload, timeout=90)
     except Exception as e:
-        logger.error(f" ✗ Remote router request failed: {e}")
-        # Try older endpoint as a fallback
+        logger.error(f" ✗ Remote request failed: {e}")
+        return ""
+
+    if r.status_code == 503:
+        logger.warning(f" ⚠️ Model loading (503), retrying in 5s...")
+        import time
+        time.sleep(5)
         try:
-            logger.info(" → Attempting legacy api-inference endpoint as fallback")
-            r = requests.post(old_url, headers=headers, json=payload, timeout=120)
-        except Exception as e2:
-            logger.error(f" ✗ Legacy endpoint request failed: {e2}")
+            r = requests.post(api_url, headers=headers, json=payload, timeout=90)
+        except Exception as e:
+            logger.error(f" ✗ Retry failed: {e}")
            return ""
 
     if r.status_code != 200:
-        logger.error(f" ✗ Remote inference error {r.status_code}: {r.text[:200]}")
+        logger.error(f" ✗ Remote inference error {r.status_code}: {r.text[:300]}")
         return ""
 
     result = r.json()
+
+    # Handle error responses
     if isinstance(result, dict) and result.get("error"):
         logger.error(f" ✗ Remote inference returned error: {result.get('error')}")
         return ""
 
-    # The HF Inference API can return a list of generated outputs or text
+    # Parse the generated text from various response formats
+    generated_text = ""
+
     if isinstance(result, list) and result:
-        # entries may be strings or dicts like {"generated_text": "..."}
+        # HF Inference API returns [{"generated_text": "..."}]
         first = result[0]
         if isinstance(first, dict):
-            return first.get("generated_text", "").strip()
-        return str(first).strip()
-
-    if isinstance(result, dict) and "generated_text" in result:
-        return result["generated_text"].strip()
+            generated_text = first.get("generated_text", "")
+        else:
+            generated_text = str(first)
+    elif isinstance(result, dict) and "generated_text" in result:
+        generated_text = result["generated_text"]
+    else:
+        generated_text = str(result)
 
-    return str(result).strip()
+    # Clean up: PHI may return the prompt + completion, extract only new text
+    generated_text = generated_text.strip()
+
+    # If the response contains the original prompt, extract only the new completion
+    if prompt in generated_text:
+        # Find where the prompt ends and new generation begins
+        prompt_end = generated_text.find(prompt) + len(prompt)
+        generated_text = generated_text[prompt_end:].strip()
+
+    return generated_text
 
 def initialize_embeddings():
     logger.info("🔄 Initializing embeddings model...")