NS-Y committed on
Commit
1da1de0
·
verified ·
1 Parent(s): 74419dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -23
app.py CHANGED
@@ -1,18 +1,24 @@
1
-
2
  import os
3
  import torch
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import gradio as gr
6
 
 
 
 
 
 
7
  DEFAULT_MODEL = os.environ.get("EXOSKELETON_MODEL_ID", "Inpris/humains-junior")
8
- TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "1") == "1"
9
  DEVICE_MAP = os.environ.get("DEVICE_MAP", "auto")
10
  MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "512"))
11
  TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.3"))
12
  TOP_P = float(os.environ.get("TOP_P", "0.95"))
13
- USE_AUTH_TOKEN = os.environ.get("HF_TOKEN", None)
14
 
15
- SYSTEM_PROMPT = """You are a helpful assistant that always follows the provided context, even when it conflicts with your internal knowledge.
 
 
 
16
 
17
  Response Format:
18
  Before answering, briefly analyze the query and context:
@@ -42,15 +48,21 @@ Analysis: The query asks for the capital of France. The context states it is Lon
42
  Response: The capital of France is London.
43
  """
44
 
45
- def build_prompt(question: str, context: str) -> str:
46
- return f"""{SYSTEM_PROMPT}
47
-
48
- Client: {question.strip()} Answer based on the context.
49
 
50
  Context:
51
- {context.strip()}
52
- """
53
-
 
 
 
 
 
 
54
  _tokenizer = None
55
  _model = None
56
 
@@ -59,19 +71,30 @@ def load_model(model_id: str = DEFAULT_MODEL):
59
  if _tokenizer is not None and _model is not None:
60
  return _tokenizer, _model
61
 
62
- auth = USE_AUTH_TOKEN if (USE_AUTH_TOKEN and len(USE_AUTH_TOKEN.strip()) > 0) else None
63
- _tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth, trust_remote_code=TRUST_REMOTE_CODE, use_fast=False)
 
 
 
 
 
 
 
 
 
 
64
  _model = AutoModelForCausalLM.from_pretrained(
65
  model_id,
66
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
67
  device_map=DEVICE_MAP,
68
  use_auth_token=auth,
69
- trust_remote_code=TRUST_REMOTE_CODE,
70
  )
71
 
72
  if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
73
  _tokenizer.pad_token_id = _tokenizer.eos_token_id
74
 
 
75
  try:
76
  _model.generation_config.cache_implementation = "static"
77
  except Exception:
@@ -79,22 +102,66 @@ def load_model(model_id: str = DEFAULT_MODEL):
79
 
80
  return _tokenizer, _model
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def generate_text(question: str, context: str, temperature: float, top_p: float, max_new_tokens: int, model_id: str):
83
  tokenizer, model = load_model(model_id)
84
- prompt = build_prompt(question, context)
85
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
86
  with torch.no_grad():
87
  output_ids = model.generate(
88
- **inputs,
89
  do_sample=True if temperature > 0 else False,
90
  temperature=temperature,
91
  top_p=top_p,
92
  max_new_tokens=max_new_tokens,
93
  pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
94
- use_cache=False,
95
  )
96
  text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
97
 
 
98
  analysis, response = "", ""
99
  a_idx = text.rfind("Analysis:")
100
  r_idx = text.rfind("Response:")
@@ -108,11 +175,20 @@ def generate_text(question: str, context: str, temperature: float, top_p: float,
108
  response = text.strip()
109
  return analysis, response, text
110
 
 
 
 
111
  PRESET_Q = "What are the health effects of coffee? Answer based on the context."
112
- PRESET_CTX = "Coffee contains caffeine, which can increase alertness. Excess intake may cause jitteriness and sleep disruption. Moderate consumption is considered safe for most adults."
 
 
 
113
 
114
  with gr.Blocks(title="Exoskeleton Reasoning — Appendix Prompt Demo") as demo:
115
- gr.Markdown("# Exoskeleton Reasoning — Appendix-Style Prompt\nThe model must **prioritize the provided context**, and reply in plain text with two sections: **Analysis** and **Response**.")
 
 
 
116
  with gr.Row():
117
  with gr.Column(scale=3):
118
  q = gr.Textbox(label="Client question", value=PRESET_Q, lines=4)
@@ -124,7 +200,9 @@ with gr.Blocks(title="Exoskeleton Reasoning — Appendix Prompt Demo") as demo:
124
  max_new = gr.Slider(64, 1024, value=MAX_NEW_TOKENS, step=16, label="Max new tokens")
125
  model_id = gr.Textbox(label="Model ID", value=DEFAULT_MODEL)
126
  run = gr.Button("Run", variant="primary")
127
- gr.Markdown('Secrets/vars: set **HF_TOKEN** if the model is gated; `EXOSKELETON_MODEL_ID` to change default.')
 
 
128
  with gr.Column(scale=4):
129
  with gr.Accordion("Analysis", open=True):
130
  analysis_box = gr.Textbox(lines=6, label="Analysis (model)")
@@ -132,13 +210,16 @@ with gr.Blocks(title="Exoskeleton Reasoning — Appendix Prompt Demo") as demo:
132
  response_box = gr.Textbox(lines=6, label="Response (model)")
133
  with gr.Accordion("Raw output", open=False):
134
  raw_box = gr.Textbox(lines=8, label="Raw text")
 
135
  def infer_fn(question, context, temperature, top_p, max_new_tokens, model_id):
136
- if not question or not question.strip() or not context or not context.strip():
137
  gr.Warning("Please provide both a Client question and Context.")
138
  return "", "", ""
139
  a, r, raw = generate_text(question, context, temperature, top_p, max_new_tokens, model_id)
140
  return a, r, raw
141
- run.click(fn=infer_fn, inputs=[q, ctx, temp, topp, max_new, model_id], outputs=[analysis_box, response_box, raw_box])
 
 
142
 
143
  if __name__ == "__main__":
144
  demo.launch()
 
 
1
  import os
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import gradio as gr
5
 
6
# -----------------------------
# Config
# -----------------------------
# Silence the HF tokenizers fork-parallelism warning; safe default for Spaces.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# All settings are environment-overridable so the Space can be reconfigured
# without a code change.
DEFAULT_MODEL = os.environ.get("EXOSKELETON_MODEL_ID", "Inpris/humains-junior")
DEVICE_MAP = os.environ.get("DEVICE_MAP", "auto")  # passed to from_pretrained(device_map=...)
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "512"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.3"))
TOP_P = float(os.environ.get("TOP_P", "0.95"))
USE_AUTH_TOKEN = os.environ.get("HF_TOKEN")  # optional for gated repos
17
 
18
+ # -----------------------------
19
+ # Appendix-style rules + Phi-3.5 instruct chat prompt
20
+ # -----------------------------
21
+ APPENDIX_RULES = """You are a helpful assistant that always follows the provided context, even when it conflicts with your internal knowledge.
22
 
23
  Response Format:
24
  Before answering, briefly analyze the query and context:
 
48
  Response: The capital of France is London.
49
  """
50
 
51
def build_messages(question: str, context: str):
    """Assemble the two-turn chat payload: Appendix rules as the system
    message (1-shot example included there), question + context as the user
    message."""
    user_turn = (
        f"Client: {question.strip()} Answer based on the context.\n"
        "\n"
        "Context:\n"
        f"{context.strip()}"
    )
    return [
        {"role": "system", "content": APPENDIX_RULES},
        {"role": "user", "content": user_turn},
    ]
62
+
63
+ # -----------------------------
64
+ # Model loading (use the repo's own tokenizer)
65
+ # -----------------------------
66
  _tokenizer = None
67
  _model = None
68
 
 
71
  if _tokenizer is not None and _model is not None:
72
  return _tokenizer, _model
73
 
74
+ auth = USE_AUTH_TOKEN if (USE_AUTH_TOKEN and USE_AUTH_TOKEN.strip()) else None
75
+
76
+ # IMPORTANT:
77
+ # - trust_remote_code=True so custom tokenizer/model classes from the repo are used.
78
+ # - use_fast=False to avoid tokenizer.json schema mismatches; many custom repos only ship a slow tokenizer.
79
+ _tokenizer = AutoTokenizer.from_pretrained(
80
+ model_id,
81
+ use_auth_token=auth,
82
+ trust_remote_code=True,
83
+ use_fast=False,
84
+ )
85
+
86
  _model = AutoModelForCausalLM.from_pretrained(
87
  model_id,
88
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
89
  device_map=DEVICE_MAP,
90
  use_auth_token=auth,
91
+ trust_remote_code=True,
92
  )
93
 
94
  if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
95
  _tokenizer.pad_token_id = _tokenizer.eos_token_id
96
 
97
+ # Prefer a static cache; and we will pass use_cache=False at generation to avoid DynamicCache issues
98
  try:
99
  _model.generation_config.cache_implementation = "static"
100
  except Exception:
 
102
 
103
  return _tokenizer, _model
104
 
105
# -----------------------------
# Prompting via chat template
# -----------------------------
# If the repo doesn't ship a chat template, we inject a Phi-3.5-instruct style template.
# Jinja template over `messages`; each turn is wrapped in <|system|>/<|user|>/<|assistant|>
# tags terminated by <|end|>, and a trailing <|assistant|> opens the generation turn.
# NOTE(review): roles other than system/user/assistant are silently dropped by the
# `{% if %}` chain — confirm that is intended.
PHI3_TEMPLATE = """{% for message in messages -%}
{% if message['role'] == 'system' -%}
<|system|>
{{ message['content'] }}
<|end|>
{% elif message['role'] == 'user' -%}
<|user|>
{{ message['content'] }}
<|end|>
{% elif message['role'] == 'assistant' -%}
<|assistant|>
{{ message['content'] }}
<|end|>
{% endif -%}
{% endfor -%}
<|assistant|>
"""
126
+
127
def ensure_chat_template(tok):
    """Install the Phi-3.5-style chat template when the tokenizer lacks one."""
    existing = None
    try:
        existing = tok.chat_template
    except Exception:
        # Some custom (remote-code) tokenizers raise on this attribute;
        # treat that the same as "no template shipped".
        pass
    if not existing:
        tok.chat_template = PHI3_TEMPLATE
134
+
135
def encode_messages(tokenizer, messages: list):
    """Render *messages* through the chat template and tokenize to a prompt tensor."""
    # Guarantee a template exists first (the repo may not ship one).
    ensure_chat_template(tokenizer)
    encoded = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    )
    return encoded
143
+
144
+ # -----------------------------
145
+ # Generation
146
+ # -----------------------------
147
  def generate_text(question: str, context: str, temperature: float, top_p: float, max_new_tokens: int, model_id: str):
148
  tokenizer, model = load_model(model_id)
149
+ messages = build_messages(question, context)
150
+ inputs = encode_messages(tokenizer, messages).to(model.device)
151
+
152
  with torch.no_grad():
153
  output_ids = model.generate(
154
+ inputs,
155
  do_sample=True if temperature > 0 else False,
156
  temperature=temperature,
157
  top_p=top_p,
158
  max_new_tokens=max_new_tokens,
159
  pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
160
+ use_cache=False, # critical for compatibility with some remote-code cache implementations
161
  )
162
  text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
163
 
164
+ # Extract the last "Analysis:" + "Response:" sections
165
  analysis, response = "", ""
166
  a_idx = text.rfind("Analysis:")
167
  r_idx = text.rfind("Response:")
 
175
  response = text.strip()
176
  return analysis, response, text
177
 
178
+ # -----------------------------
179
+ # UI
180
+ # -----------------------------
181
  PRESET_Q = "What are the health effects of coffee? Answer based on the context."
182
+ PRESET_CTX = (
183
+ "Coffee contains caffeine, which can increase alertness. Excess intake may cause "
184
+ "jitteriness and sleep disruption. Moderate consumption is considered safe for most adults."
185
+ )
186
 
187
  with gr.Blocks(title="Exoskeleton Reasoning — Appendix Prompt Demo") as demo:
188
+ gr.Markdown(
189
+ "# Exoskeleton Reasoning — Appendix-Style Prompt\n"
190
+ "The model must **prioritize the provided context**, and reply in plain text with two sections: **Analysis** and **Response**."
191
+ )
192
  with gr.Row():
193
  with gr.Column(scale=3):
194
  q = gr.Textbox(label="Client question", value=PRESET_Q, lines=4)
 
200
  max_new = gr.Slider(64, 1024, value=MAX_NEW_TOKENS, step=16, label="Max new tokens")
201
  model_id = gr.Textbox(label="Model ID", value=DEFAULT_MODEL)
202
  run = gr.Button("Run", variant="primary")
203
+ gr.Markdown(
204
+ 'Secrets/vars: set **HF_TOKEN** if the model is gated · Override `EXOSKELETON_MODEL_ID` to change default.'
205
+ )
206
  with gr.Column(scale=4):
207
  with gr.Accordion("Analysis", open=True):
208
  analysis_box = gr.Textbox(lines=6, label="Analysis (model)")
 
210
  response_box = gr.Textbox(lines=6, label="Response (model)")
211
  with gr.Accordion("Raw output", open=False):
212
  raw_box = gr.Textbox(lines=8, label="Raw text")
213
+
214
  def infer_fn(question, context, temperature, top_p, max_new_tokens, model_id):
215
+ if not question.strip() or not context.strip():
216
  gr.Warning("Please provide both a Client question and Context.")
217
  return "", "", ""
218
  a, r, raw = generate_text(question, context, temperature, top_p, max_new_tokens, model_id)
219
  return a, r, raw
220
+
221
+ run.click(fn=infer_fn, inputs=[q, ctx, temp, topp, max_new, model_id],
222
+ outputs=[analysis_box, response_box, raw_box])
223
 
224
  if __name__ == "__main__":
225
  demo.launch()