iamthewalrus67 committed on
Commit
b04ad10
·
1 Parent(s): 4921d80

Rewrite to work with score models

Browse files
Files changed (1) hide show
  1. app.py +62 -264
app.py CHANGED
@@ -27,288 +27,86 @@ from huggingface_hub import login
27
  login(token=HF_LE_LLM_READ_TOKEN)
28
 
29
  # Constants
30
- # MODEL_ID = "le-llm/lapa-v0.1-reasoning-only-32768"
31
- # MODEL_ID = "le-llm/lapa-v0.1-instruct"
32
- # MODEL_ID = "le-llm/lapa-v0.1-matt-instruction-5e06"
33
- # MODEL_ID = "le-llm/lapa-v0.1-reprojected"
34
- # MODEL_ID = "le-llm/lapa-v0.1.1-instruct"
35
- MODEL_ID = "le-llm/manipulative-score-model"
36
 
37
- MAX_TOKENS = 4096
38
- TEMPERATURE = 0.7
39
- TOP_P = 0.95
40
 
41
- IMAGE_MAX_SIZE = 1024
 
42
 
43
 
44
- def _begin_analytics_session():
45
- # Called once per client on app load
46
- pass
47
- # _ = logger.start_session(MODEL_ID)
48
 
49
- def load_model():
50
- """Lazy-load model, tokenizer, and optional processor (for zeroGPU)."""
51
- device = "cuda" # if torch.cuda.is_available() else "cpu"
52
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
53
- processor = None
54
- try:
55
- processor = AutoProcessor.from_pretrained(MODEL_ID)
56
- except Exception as err: # pragma: no cover - informative fallback
57
- print(f"Warning: AutoProcessor not available ({err}). Falling back to tokenizer.")
58
 
59
- model = AutoModel.from_pretrained(
60
- MODEL_ID,
61
- dtype=torch.bfloat16, # if device == "cuda" else torch.float32,
62
- device_map="cuda", # if device == "cuda" else None,
63
- ) # .cuda()
64
- print(f"Selected device:", device)
65
- return model, tokenizer, processor, device
66
 
 
 
 
 
67
 
68
- # Load model/tokenizer each request → allows zeroGPU to cold start & then release
69
- model, tokenizer, processor, device = load_model()
70
 
71
-
72
- def user(user_message, history: list):
73
- """Format user message with optional image."""
74
- import io
75
-
76
- user_message = user_message or ""
77
- updated_history = list(history)
78
- has_content = False
79
-
80
- stripped_message = user_message.strip()
81
-
82
- if stripped_message:
83
- has_content = True
84
-
85
- if not has_content:
86
- # Nothing to submit yet; keep inputs unchanged
87
- return user_message, history
88
-
89
- return "", updated_history
90
-
91
-
92
- def append_example_message(x: gr.SelectData, history):
93
- if x.value["text"] is not None:
94
- history.append({"role": "user", "content": x.value["text"]})
95
-
96
- return history
97
-
98
-
99
- def _extract_text_from_content(content: Any) -> str | tuple[str, str]:
100
- """Extract text from message content for logging."""
101
- if isinstance(content, str):
102
- return content
103
- if isinstance(content, tuple) and len(content) == 2:
104
- return content # (image_path, user_text)
105
-
106
- raise ValueError(f"Unsupported content type for text extraction: {content}")
107
-
108
-
109
- def _clean_history_for_display(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
110
- """Remove internal metadata fields like _base64 before displaying in Gradio."""
111
- cleaned = []
112
- for message in history:
113
- cleaned_message = {"role": message.get("role", "user")}
114
- content = message.get("content")
115
-
116
- if isinstance(content, str):
117
- cleaned_message["content"] = content
118
- elif isinstance(content, list):
119
- cleaned_content = []
120
- for item in content:
121
- if isinstance(item, dict):
122
- # Remove _base64 metadata
123
- cleaned_item = {k: v for k, v in item.items() if not k.startswith("_")}
124
- cleaned_content.append(cleaned_item)
125
- else:
126
- cleaned_content.append(item)
127
- cleaned_message["content"] = cleaned_content
128
- else:
129
- cleaned_message["content"] = content
130
-
131
- cleaned.append(cleaned_message)
132
-
133
- return cleaned
134
 
135
 
 
136
  @spaces.GPU
137
- def bot(
138
- input: list[dict[str, Any]]
139
- ):
140
- """Generate bot response with support for text."""
141
-
142
- # Early return if no input
143
- if not input:
144
- return
145
-
146
- clean_input = f"query: {input}"
147
- batch_dict = tokenizer(clean_input, max_length=512, padding=True, truncation=True, return_tensors='pt')
148
- batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
149
-
150
- outputs = model(**batch_dict)
151
- embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
152
-
153
- embeddings = F.normalize(embeddings, p=2, dim=1)
154
- scores = (embeddings[:2] @ embeddings[2:].T) * 100
155
- return str(scores.tolist())
156
-
 
 
 
 
 
 
 
 
 
 
157
 
158
- def average_pool(last_hidden_states: torch.Tensor,
159
- attention_mask: torch.Tensor) -> torch.Tensor:
160
- last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
161
- return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
162
-
163
- # --- drop-in UI compatible with older Gradio versions ---
164
- import os, tempfile, time
165
- import gradio as gr
166
 
167
- # Ukrainian-inspired theme with deep, muted colors reflecting unbeatable spirit:
168
- THEME = gr.themes.Soft(
169
- primary_hue="blue", # Deep blue representing Ukrainian sky and resolve
170
- secondary_hue="amber", # Warm amber representing golden fields and determination
171
- neutral_hue="stone", # Earthy stone representing strength and foundation
172
- )
173
 
174
- # Load CSS from external file
175
- def load_css():
176
- try:
177
- with open("static/style.css", "r", encoding="utf-8") as f:
178
- return f.read()
179
- except FileNotFoundError:
180
- print("Warning: static/style.css not found")
181
- return ""
182
 
183
- CSS = load_css()
 
184
 
185
- def _clear_chat():
186
- return "", []
 
187
 
188
- with gr.Blocks(theme=THEME, css=CSS, fill_height=True) as demo:
189
- demo.load(fn=_begin_analytics_session, inputs=None, outputs=None)
190
-
191
-
192
- # Header (no gr.Box to avoid version issues)
193
- gr.HTML(
194
- """
195
- <div id="app-header">
196
- <div class="app-title">🤔 LAPA Quality Estimation</div>
197
- </div>
198
- """
199
- )
200
-
201
- with gr.Row(equal_height=True):
202
- # Left side: Chat
203
- with gr.Column(scale=7, elem_id="left-pane"):
204
- with gr.Column(elem_id="chat-card"):
205
- chatbot = gr.Chatbot(
206
- type="messages",
207
- height=560,
208
- render_markdown=True,
209
- show_copy_button=True,
210
- show_label=False,
211
- # likeable=True,
212
- allow_tags=["think"],
213
- elem_id="chatbot",
214
- examples=[
215
- {"text": i}
216
- for i in [
217
- "хто тримає цей район?",
218
- "Напиши історію про Івасика-Телесика",
219
- "Яка найвища гора в Україні?",
220
- "Як звали батька Тараса Григоровича Шевченка?",
221
- "Яка з цих гір не знаходиться у Європі? Говерла, Монблан, Гран-Парадізо, Еверест",
222
- "Дай відповідь на питання\nЧому у качки жовті ноги?",
223
- ]
224
- ],
225
- )
226
-
227
-
228
- # ChatGPT-style input box with stop button
229
- with gr.Row(elem_id="chat-input-row"):
230
- msg = gr.Textbox(
231
- label=None,
232
- placeholder="Message… (Press Enter to send)",
233
- autofocus=True,
234
- lines=1,
235
- max_lines=6,
236
- container=False,
237
- show_label=False,
238
- elem_id="chat-input",
239
- elem_classes=["chat-input-box"]
240
- )
241
- stop_btn_visible = gr.Button(
242
- "⏹️",
243
- variant="secondary",
244
- elem_id="stop-btn-visible",
245
- elem_classes=["stop-btn-chat"],
246
- visible=False,
247
- size="sm"
248
- )
249
-
250
- # Hidden buttons for functionality
251
- with gr.Row(visible=True, elem_id="hidden-buttons"):
252
- send_btn = gr.Button("Send", variant="primary", elem_id="send-btn")
253
- stop_btn = gr.Button("Stop", variant="secondary", elem_id="stop-btn")
254
- clear_btn = gr.Button("Clear", variant="secondary", elem_id="clear-btn")
255
-
256
- # export_btn = gr.Button("Export chat (.md)", variant="secondary", elem_classes=["rounded-btn","secondary-btn"])
257
- # exported_file = gr.File(label="", interactive=False, visible=True)
258
- gr.HTML('<div class="footer-tip">Shortcuts: Enter to send • Shift+Enter for new line</div>')
259
-
260
- # Helper functions for managing UI state
261
- def show_stop_button():
262
- return gr.update(visible=True)
263
-
264
- def hide_stop_button():
265
- return gr.update(visible=False)
266
-
267
- # Events (preserve your original handlers)
268
- e1 = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=True).then(
269
- fn=show_stop_button, inputs=None, outputs=stop_btn_visible
270
- ).then(
271
- fn=bot, inputs=chatbot, outputs=chatbot
272
- ).then(
273
- fn=hide_stop_button, inputs=None, outputs=stop_btn_visible
274
- )
275
-
276
- e2 = send_btn.click(fn=user, inputs=[msg, chatbot], outputs=[msg,chatbot], queue=True).then(
277
- fn=show_stop_button, inputs=None, outputs=stop_btn_visible
278
- ).then(
279
- fn=bot, inputs=chatbot, outputs=chatbot
280
- ).then(
281
- fn=hide_stop_button, inputs=None, outputs=stop_btn_visible
282
- )
283
-
284
- e3 = chatbot.example_select(fn=append_example_message, inputs=[chatbot], outputs=[chatbot], queue=True).then(
285
- fn=show_stop_button, inputs=None, outputs=stop_btn_visible
286
- ).then(
287
- fn=bot, inputs=chatbot, outputs=chatbot
288
- ).then(
289
- fn=hide_stop_button, inputs=None, outputs=stop_btn_visible
290
- )
291
-
292
- # Stop cancels running events (both buttons work)
293
- stop_btn.click(fn=hide_stop_button, inputs=None, outputs=stop_btn_visible, cancels=[e1, e2, e3], queue=True)
294
- stop_btn_visible.click(fn=hide_stop_button, inputs=None, outputs=stop_btn_visible, cancels=[e1, e2, e3], queue=True)
295
-
296
- # Clear chat + input
297
- clear_btn.click(fn=_clear_chat, inputs=None, outputs=[msg, chatbot])
298
-
299
- # Export markdown
300
- # export_btn.click(fn=_export_markdown, inputs=chatbot, outputs=exported_file)
301
-
302
- # Load and inject external JavaScript
303
- def load_javascript():
304
- try:
305
- with open("static/script.js", "r", encoding="utf-8") as f:
306
- return f"<script>{f.read()}</script>"
307
- except FileNotFoundError:
308
- print("Warning: static/script.js not found")
309
- return ""
310
-
311
- gr.HTML(load_javascript())
312
 
313
  if __name__ == "__main__":
314
  demo.queue().launch()
 
27
  login(token=HF_LE_LLM_READ_TOKEN)
28
 
29
# Constants
DEFAULT_MODEL = "le-llm/manipulative-score-model"
# Prefer the GPU but fall back to CPU so local / CPU-only runs don't crash
# on model.to("cuda").
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
32
 
 
 
 
33
 
34
# --- Cache to avoid repeated reloads ---
# Maps model_id -> (model, tokenizer) so each model is downloaded and
# initialised at most once per process. Builtin generics (dict/tuple) match
# the annotation style used elsewhere in this file (e.g. list[dict[str, Any]]).
_model_cache: dict[str, tuple[torch.nn.Module, AutoTokenizer]] = {}
36
 
37
 
38
def load_model(model_id: str):
    """Load and cache a model + tokenizer for `model_id`.

    Returns:
        tuple: (model, tokenizer), with the model in eval mode on DEVICE.

    The model is loaded via ``AutoModel`` in bfloat16 and cached in
    ``_model_cache`` so repeated requests reuse the already-loaded weights.

    NOTE(review): despite earlier wording, there is no embedding-vs-regression
    auto-detection here — a plain ``AutoModel`` backbone is always loaded.
    """
    # Serve from the process-level cache when already loaded.
    if model_id in _model_cache:
        return _model_cache[model_id]

    print(f"🔹 Loading model: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # bfloat16 halves memory vs float32; assumes DEVICE supports bf16 — TODO confirm on CPU.
    model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    print(f"Detected embedding model: {model_id}")

    # Inference-only usage: move to target device and disable dropout etc.
    model.to(DEVICE).eval()
    _model_cache[model_id] = (model, tokenizer)
    print(f"✅ Loaded model on {DEVICE}")
    return model, tokenizer
53
 
 
 
54
 
55
# --- Helper for embeddings ---
def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings, counting only non-padding positions.

    Padding positions (attention_mask == 0) are zeroed out before summing,
    then each sequence's sum is divided by its number of real tokens.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(~keep, 0.0)
    token_counts = attention_mask.sum(dim=1, keepdim=True)
    return zeroed.sum(dim=1) / token_counts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
61
# --- Main scoring logic ---
@spaces.GPU
def bot(user_message: str, history: list[dict[str, Any]], model_choice: str):
    """Score `user_message` with the selected model and append the result.

    Args:
        user_message: Raw text from the input box.
        history: Chat history in Gradio "messages" format.
        model_choice: Model id selected in the dropdown.

    Returns:
        tuple: ("", updated_history) — clears the textbox and shows the score
        as an assistant message.
    """
    # Ignore empty / whitespace-only submissions; leave the UI unchanged.
    if not user_message.strip():
        return "", history

    model, tokenizer = load_model(model_choice)
    history = history + [{"role": "user", "content": user_message}]

    batch = tokenizer([user_message], padding=True, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**batch)
        # outputs.last_hidden_state: [batch, seq_len, hidden_dim] → one vector per sequence.
        embedding = average_pool(outputs.last_hidden_state, batch["attention_mask"])
        # A plain AutoModel encoder has no regression head; fail with a clear
        # chat message instead of an AttributeError on `score_head`.
        score_head = getattr(model, "score_head", None)
        if score_head is None:
            response = f"⚠️ {model_choice} has no `score_head`; cannot compute a score."
            history.append({"role": "assistant", "content": response})
            return "", history
        score = score_head(embedding).squeeze().item()

    response = f"🔹 {model_choice} → score: {score:.4f}"
    history.append({"role": "assistant", "content": response})
    return "", history
85
+
86
# --- UI ---
# Soft Gradio theme; blue/amber/stone palette.
THEME = gr.themes.Soft(primary_hue="blue", secondary_hue="amber", neutral_hue="stone")

# Score models selectable from the dropdown; DEFAULT_MODEL must appear here.
MODEL_OPTIONS = [
    "le-llm/manipulative-score-model",
    "le-llm/gec-score-model"
]
93
 
94
+ def _clear_chat():
95
+ return "", []
 
 
 
 
 
 
96
 
 
 
 
 
 
 
97
 
98
with gr.Blocks(theme=THEME, fill_height=True) as demo:
    gr.Markdown("### 🤔 LAPA Quality Estimation")

    # Model selector row.
    with gr.Row():
        model_choice = gr.Dropdown(MODEL_OPTIONS, value=DEFAULT_MODEL, label="Select Model")

    # Chat area: history display, text input, and a clear button.
    chatbot = gr.Chatbot(type="messages", height=480)
    msg = gr.Textbox(label=None, placeholder="Type your text…", lines=1)
    clear_btn = gr.Button("Clear")

    # Enter in the textbox sends the message to `bot`, which returns
    # ("", updated_history) — clearing the input and refreshing the chat.
    msg.submit(bot, inputs=[msg, chatbot, model_choice], outputs=[msg, chatbot])
    clear_btn.click(_clear_chat, outputs=[msg, chatbot])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
if __name__ == "__main__":
    # Enable request queueing (required for @spaces.GPU) and launch the app.
    demo.queue().launch()