NS-Y committed on
Commit
6ee67a7
·
verified ·
1 Parent(s): 107e86b

Update app.py

Browse files

friendly model

Files changed (1) hide show
  1. app.py +188 -121
app.py CHANGED
@@ -1,20 +1,32 @@
1
-
2
  import os
 
 
3
  import json
4
- import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
6
- from transformers.models.llama import LlamaTokenizer # force slow llama if needed
7
  import gradio as gr
8
 
9
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- DEFAULT_MODEL = os.environ.get("EXOSKELETON_MODEL_ID", "Inpris/humains-junior")
12
- DEVICE_MAP = os.environ.get("DEVICE_MAP", "auto")
13
- MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "512"))
14
- TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.3"))
15
- TOP_P = float(os.environ.get("TOP_P", "0.95"))
16
- USE_AUTH_TOKEN = os.environ.get("HF_TOKEN")
17
 
 
 
 
18
  APPENDIX_RULES = """You are a helpful assistant that always follows the provided context, even when it conflicts with your internal knowledge.
19
 
20
  Response Format:
@@ -45,147 +57,202 @@ Analysis: The query asks for the capital of France. The context states it is Lon
45
  Response: The capital of France is London.
46
  """
47
 
48
- PHI3_TEMPLATE = """{% for message in messages -%}
49
- {% if message['role'] == 'system' -%}
50
- <|system|>
51
- {{ message['content'] }}
52
- <|end|>
53
- {% elif message['role'] == 'user' -%}
54
- <|user|>
55
- {{ message['content'] }}
56
- <|end|>
57
- {% elif message['role'] == 'assistant' -%}
58
- <|assistant|>
59
- {{ message['content'] }}
60
- <|end|>
61
- {% endif -%}
62
- {% endfor -%}
63
- <|assistant|>
64
- """
65
-
66
  def build_messages(question: str, context: str):
 
 
 
 
 
 
67
  system = APPENDIX_RULES
68
  user = f"""Client: {question.strip()} Answer based on the context.
69
 
70
  Context:
71
  {context.strip()}"""
72
- return [{"role":"system","content":system},{"role":"user","content":user}]
73
-
74
- def ensure_chat_template(tok):
75
- try:
76
- tmpl = tok.chat_template
77
- except Exception:
78
- tmpl = None
79
- if not tmpl:
80
- tok.chat_template = PHI3_TEMPLATE
81
-
82
- def encode_messages(tokenizer, messages: list):
83
- ensure_chat_template(tokenizer)
84
- return tokenizer.apply_chat_template(
85
- messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
86
- )
87
 
88
- _tokenizer = None
89
- _model = None
90
-
91
- def load_tokenizer_robust(model_id: str, auth):
92
- try:
93
- return AutoTokenizer.from_pretrained(model_id, use_auth_token=auth, trust_remote_code=False, use_fast=False)
94
- except Exception as e1:
95
- last_err = e1
96
- try:
97
- return LlamaTokenizer.from_pretrained(model_id, use_auth_token=auth)
98
- except Exception as e2:
99
- last_err = e2
100
- try:
101
- return AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct", use_auth_token=auth, trust_remote_code=False, use_fast=False)
102
- except Exception as e3:
103
- raise last_err
104
-
105
- def load_model(model_id: str = DEFAULT_MODEL):
106
- global _tokenizer, _model
107
- if _tokenizer is not None and _model is not None:
108
- return _tokenizer, _model
109
-
110
- auth = USE_AUTH_TOKEN if (USE_AUTH_TOKEN and USE_AUTH_TOKEN.strip()) else None
111
-
112
- _tokenizer = load_tokenizer_robust(model_id, auth)
113
- if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
114
- _tokenizer.pad_token_id = _tokenizer.eos_token_id
115
-
116
- _model = AutoModelForCausalLM.from_pretrained(
117
- model_id,
118
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
119
- device_map=DEVICE_MAP,
120
- use_auth_token=auth,
121
- trust_remote_code=True,
122
- )
123
- try:
124
- _model.generation_config.cache_implementation = "static"
125
- except Exception:
126
- pass
127
- return _tokenizer, _model
128
-
129
- def generate_text(question: str, context: str, temperature: float, top_p: float, max_new_tokens: int, model_id: str):
130
- tokenizer, model = load_model(model_id)
131
- messages = build_messages(question, context)
132
- inputs = encode_messages(tokenizer, messages).to(model.device)
133
- with torch.no_grad():
134
- output_ids = model.generate(
135
- inputs,
136
- do_sample=True if temperature > 0 else False,
137
- temperature=temperature,
138
- top_p=top_p,
139
- max_new_tokens=max_new_tokens,
140
- pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
141
- use_cache=False,
142
- )
143
- text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
144
 
145
- analysis, response = "", ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  a_idx = text.rfind("Analysis:")
147
  r_idx = text.rfind("Response:")
 
148
  if a_idx != -1 and (r_idx == -1 or a_idx < r_idx):
149
  if r_idx != -1:
150
- analysis = text[a_idx+len("Analysis:"):r_idx].strip()
151
- response = text[r_idx+len("Response:"):].strip()
152
  else:
153
- analysis = text[a_idx+len("Analysis:"):].strip()
154
  else:
155
  response = text.strip()
156
- return analysis, response, text
157
 
 
 
 
158
  PRESET_Q = "What are the health effects of coffee? Answer based on the context."
159
- PRESET_CTX = "Coffee contains caffeine, which can increase alertness. Excess intake may cause jitteriness and sleep disruption. Moderate consumption is considered safe for most adults."
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- with gr.Blocks(title="Exoskeleton Reasoning — Appendix Prompt Demo") as demo:
162
- gr.Markdown("# Exoskeleton Reasoning — Appendix-Style Prompt\nThe model must **prioritize the provided context**, and reply in plain text with two sections: **Analysis** and **Response**.")
163
  with gr.Row():
164
  with gr.Column(scale=3):
165
  q = gr.Textbox(label="Client question", value=PRESET_Q, lines=4)
166
  ctx = gr.Textbox(label="Context (the source you must follow)", value=PRESET_CTX, lines=8)
 
167
  with gr.Row():
168
- temp = gr.Slider(0.0, 1.2, value=TEMPERATURE, step=0.05, label="Temperature")
169
- topp = gr.Slider(0.1, 1.0, value=TOP_P, step=0.05, label="Top-p")
170
- with gr.Row():
171
- max_new = gr.Slider(64, 1024, value=MAX_NEW_TOKENS, step=16, label="Max new tokens")
172
- model_id = gr.Textbox(label="Model ID", value=DEFAULT_MODEL)
 
 
 
 
 
173
  run = gr.Button("Run", variant="primary")
174
- gr.Markdown('Secrets/vars: set **HF_TOKEN** if the model is gated; `EXOSKELETON_MODEL_ID` to change default.')
 
 
 
175
  with gr.Column(scale=4):
176
  with gr.Accordion("Analysis", open=True):
177
- analysis_box = gr.Textbox(lines=6, label="Analysis (model)")
178
  with gr.Accordion("Response", open=True):
179
- response_box = gr.Textbox(lines=6, label="Response (model)")
180
  with gr.Accordion("Raw output", open=False):
181
  raw_box = gr.Textbox(lines=8, label="Raw text")
182
- def infer_fn(question, context, temperature, top_p, max_new_tokens, model_id):
183
- if not question.strip() or not context.strip():
 
 
 
 
 
 
184
  gr.Warning("Please provide both a Client question and Context.")
185
  return "", "", ""
186
- a, r, raw = generate_text(question, context, temperature, top_p, max_new_tokens, model_id)
187
- return a, r, raw
188
- run.click(fn=infer_fn, inputs=[q, ctx, temp, topp, max_new, model_id], outputs=[analysis_box, response_box, raw_box])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  if __name__ == "__main__":
191
  demo.launch()
 
 
1
  import os
2
+ import time
3
+ import random
4
  import json
5
+ import requests
 
 
6
  import gradio as gr
7
 
8
# ==============================
# Config via Secrets / Variables
# ==============================
# Secrets (Space: Settings → Variables & secrets → Secrets)
FRIENDLI_API_KEY = os.getenv("FRIENDLI_API_KEY", "")  # <— SECRET. Do not print/log.

# Variables (non-secret is okay; keep model id as a secret if you prefer)
# Endpoint and model id default to the dedicated Friendli deployment; both
# can be overridden per-deployment via environment variables.
FRIENDLI_ENDPOINT = os.getenv("FRIENDLI_ENDPOINT", "https://api.friendli.ai/dedicated/v1/chat/completions")
FRIENDLI_MODEL_ID = os.getenv("FRIENDLI_MODEL_ID", "stp7xzjspxe8")  # move to Secret if you want to hide it fully
DEFAULT_MAX_TOKENS = int(os.getenv("FRIENDLI_MAX_TOKENS", "2000"))  # generation budget per request
DEFAULT_TEMPERATURE = float(os.getenv("FRIENDLI_TEMPERATURE", "0.0"))  # 0.0 = greedy/deterministic sampling
DEFAULT_TIMEOUT = int(os.getenv("FRIENDLI_TIMEOUT_SEC", "60"))  # per-HTTP-request timeout, in seconds
20
 
21
+ # Safety: never leak secrets in logs
22
+ def _redact(s: str) -> str:
23
+ if not s:
24
+ return s
25
+ return s[:4] + "****" + s[-4:] if len(s) > 8 else "****"
 
26
 
27
+ # ==============================
28
+ # Appendix-style Prompt (Phi 3.5 instruct flavor)
29
+ # ==============================
30
  APPENDIX_RULES = """You are a helpful assistant that always follows the provided context, even when it conflicts with your internal knowledge.
31
 
32
  Response Format:
 
57
  Response: The capital of France is London.
58
  """
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
def build_messages(question: str, context: str):
    """Assemble the OpenAI-style message list sent to the Friendly API.

    Two turns are produced:
      - system: the appendix rules (+ one-shot example) from APPENDIX_RULES
      - user:   "Client: ... Answer based on the context." followed by the
                context block, both whitespace-trimmed.
    """
    user_turn = (
        f"Client: {question.strip()} Answer based on the context.\n"
        "\n"
        "Context:\n"
        f"{context.strip()}"
    )
    return [
        {"role": "system", "content": APPENDIX_RULES},
        {"role": "user", "content": user_turn},
    ]
 
 
 
 
 
 
 
 
 
 
 
76
 
77
# ==============================
# Friendly API client with retry
# ==============================
def call_friendly_with_retry(messages, model_id, max_tokens, temperature, timeout_sec=DEFAULT_TIMEOUT,
                             max_attempts=5, first_503_wait=10):
    """
    Call the Friendly chat-completions endpoint and return the reply text.

    Retry policy:
      - 503 (cold start / warm-up): wait `first_503_wait` seconds, retry.
      - 429 and transient 5xx HTTP errors, plus network errors: exponential
        backoff with jitter, capped at 20s per wait, up to `max_attempts`
        total attempts.

    Args:
        messages: OpenAI-style list of {"role", "content"} dicts.
        model_id: Friendly model/deployment id.
        max_tokens: generation budget (coerced to int).
        temperature: sampling temperature (coerced to float).
        timeout_sec: per-request timeout handed to requests.post.
        max_attempts: total attempts before giving up.
        first_503_wait: fixed delay (seconds) after a cold-start 503.

    Returns:
        The assistant message content, or "[EMPTY_RESPONSE]" when the API
        returns no usable text.

    Raises:
        RuntimeError: missing API key, non-retryable HTTP error, or retries
        exhausted. Error text never includes the secret key itself.
    """
    if not FRIENDLI_API_KEY:
        raise RuntimeError("Missing FRIENDLI_API_KEY secret.")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {FRIENDLI_API_KEY}",
    }
    payload = {
        "messages": messages,
        "model": model_id,
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
    }

    for attempt in range(1, max_attempts + 1):
        try:
            resp = requests.post(
                FRIENDLI_ENDPOINT,
                headers=headers,
                json=payload,
                timeout=timeout_sec,
            )
            # Cold start: the first call often returns 503 while the model
            # wakes; retry after a fixed small delay. On the final attempt
            # we fall through so raise_for_status() raises and the HTTPError
            # handler reports it (the previous duplicate raise_for_status()
            # call in the exhausted branch was redundant).
            if resp.status_code == 503 and attempt < max_attempts:
                time.sleep(first_503_wait)
                continue
            resp.raise_for_status()

            data = resp.json()
            # Defensive parsing: an empty "choices" list used to raise
            # IndexError here; treat it like an empty payload instead.
            choices = data.get("choices") or [{}]
            content = choices[0].get("message", {}).get("content", "")
            if not content or not str(content).strip():
                return "[EMPTY_RESPONSE]"
            return str(content)

        except requests.exceptions.HTTPError as http_err:
            code = getattr(http_err.response, "status_code", None)
            # Retryable statuses: rate limiting (429) and transient 5xx.
            if code in (429, 500, 502, 503, 504) and attempt < max_attempts:
                # Exp backoff with jitter
                time.sleep(min(2 ** attempt, 20) + random.uniform(0, 0.5))
                continue
            # Non-retryable or exhausted
            raise RuntimeError(f"Friendly API HTTP error (status={code}): {http_err}") from http_err

        except requests.exceptions.RequestException as net_err:
            # Network timeouts / DNS / connection errors — retry with backoff
            if attempt < max_attempts:
                time.sleep(min(2 ** attempt, 20) + random.uniform(0, 0.5))
                continue
            raise RuntimeError(f"Friendly API network error: {net_err}") from net_err

    # Should not reach here due to raises above, but just in case:
    raise RuntimeError("Failed to get response from Friendly API after retries.")
154
+
155
# ==============================
# Helpers
# ==============================
def parse_analysis_response(text: str):
    """Split plain model output into its Analysis / Response sections.

    Uses the *last* occurrence of each marker. When there is no usable
    'Analysis:' marker (absent, or appearing after the 'Response:' marker),
    the whole text is treated as the response. Returns a tuple of stripped
    strings: (analysis, response).
    """
    if not text:
        return "", ""
    marker_a, marker_r = "Analysis:", "Response:"
    pos_a = text.rfind(marker_a)
    pos_r = text.rfind(marker_r)
    # No well-ordered Analysis marker: everything is the response.
    if pos_a == -1 or (pos_r != -1 and pos_r < pos_a):
        return "", text.strip()
    # Analysis marker only: no response section.
    if pos_r == -1:
        return text[pos_a + len(marker_a):].strip(), ""
    # Both markers, in order: slice the two sections apart.
    analysis = text[pos_a + len(marker_a):pos_r].strip()
    response = text[pos_r + len(marker_r):].strip()
    return analysis, response
174
 
175
# ==============================
# UI
# ==============================
# Preset demo inputs shown when the page loads.
PRESET_Q = "What are the health effects of coffee? Answer based on the context."
PRESET_CTX = (
    "Coffee contains caffeine, which can increase alertness. Excess intake may cause "
    "jitteriness and sleep disruption. Moderate consumption is considered safe for most adults."
)

with gr.Blocks(title="Exoskeleton Reasoning — Appendix Prompt (Friendly API)") as demo:
    gr.Markdown(
        "# Exoskeleton Reasoning — Appendix-Style Prompt (Friendly API)\n"
        "- This demo **uses your Friendly endpoint** from the server (no keys in the browser).\n"
        "- The model must prioritize the provided **Context**, and reply in plain text with two sections: **Analysis** and **Response**.\n"
        "- Note: the **first call** may return **503** while the model wakes; built-in retries will handle it."
    )

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=3):
            q = gr.Textbox(label="Client question", value=PRESET_Q, lines=4)
            ctx = gr.Textbox(label="Context (the source you must follow)", value=PRESET_CTX, lines=8)

            with gr.Row():
                temp = gr.Slider(0.0, 1.0, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
                max_new = gr.Slider(64, 4000, value=DEFAULT_MAX_TOKENS, step=32, label="Max tokens")

            # Optional override (kept server-side; not exposed to client JS)
            model_id_box = gr.Textbox(
                label="Model ID (server-side override)",
                value=FRIENDLI_MODEL_ID,
                type="password",  # visually hides value in the UI (still server-side)
            )

            run = gr.Button("Run", variant="primary")
            # Static info line; the model id is deliberately not echoed here.
            tips = gr.Markdown(
                f"**Server config** — endpoint: `{FRIENDLI_ENDPOINT}` · model: hidden · timeout: {DEFAULT_TIMEOUT}s"
            )

        # Right column: the three output panes.
        with gr.Column(scale=4):
            with gr.Accordion("Analysis", open=True):
                analysis_box = gr.Textbox(lines=8, label="Analysis (model)")
            with gr.Accordion("Response", open=True):
                response_box = gr.Textbox(lines=8, label="Response (model)")
            with gr.Accordion("Raw output", open=False):
                raw_box = gr.Textbox(lines=8, label="Raw text")

    def infer_fn(question, context, temperature, max_tokens, model_id_override):
        """Server-side click handler: validate inputs, call the API, split output.

        Returns (analysis, response, raw_text); empty strings on invalid input.
        """
        if not FRIENDLI_API_KEY:
            raise gr.Error("Server is missing FRIENDLI_API_KEY secret. Add it in Settings → Variables & secrets.")

        question = (question or "").strip()
        context = (context or "").strip()
        if not question or not context:
            gr.Warning("Please provide both a Client question and Context.")
            return "", "", ""

        # Never expose secrets/endpoint; all calls are server-side
        messages = build_messages(question, context)

        # Resolve model id strictly server-side
        model_id = (model_id_override or "").strip() or FRIENDLI_MODEL_ID

        # Do the call with retries
        text = call_friendly_with_retry(
            messages=messages,
            model_id=model_id,
            max_tokens=max_tokens,
            temperature=temperature,
            timeout_sec=DEFAULT_TIMEOUT,
            max_attempts=5,
            first_503_wait=10,
        )

        analysis, response = parse_analysis_response(text)
        return analysis, response, text

    run.click(
        fn=infer_fn,
        inputs=[q, ctx, temp, max_new, model_id_box],
        outputs=[analysis_box, response_box, raw_box]
    )
256
 
257
if __name__ == "__main__":
    # Start the Gradio server only when the script is executed directly.
    demo.launch()