davda54 commited on
Commit
90fa8f8
·
verified ·
1 Parent(s): 086869b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -192
app.py CHANGED
@@ -9,7 +9,8 @@ from typing import Dict, List, Tuple
9
  import hashlib
10
  import itertools
11
  from datasets import load_dataset, Dataset, DatasetDict
12
- from huggingface_hub import HfApi, create_repo, repo_exists
 
13
  import threading
14
 
15
  from collections.abc import Iterable
@@ -218,8 +219,8 @@ TODO
218
  """
219
 
220
  # Configuration for the output dataset
221
- OUTPUT_DATASET_NAME = "ltg/fluency-annotations" # Change to your desired dataset name
222
- OUTPUT_DATASET_PRIVATE = True # Keep the annotations dataset private
223
 
224
  HF_TOKEN = os.environ.get("HF_TOKEN")
225
 
@@ -229,6 +230,83 @@ MODEL_NAMES = ["mistral-Nemo", "translated-SFT", "on-policy-RL"]
229
  # Create all pairwise comparisons
230
  MODEL_PAIRS = list(itertools.combinations(MODEL_NAMES, 2))
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  def load_dataset_samples():
233
  """Load and prepare dataset samples with pairwise comparisons"""
234
  try:
@@ -293,99 +371,17 @@ DATASET_SAMPLES = load_dataset_samples()
293
 
294
  class AnnotationManager:
295
  def __init__(self):
296
- self.annotations = {} # Store annotations by user_id
297
- self.user_states = {} # Track each user's progress
298
- self.annotation_cache = [] # Cache for batch uploads
299
- self.lock = threading.Lock() # Thread safety for annotations
300
-
301
- # Initialize or load existing annotations dataset
302
- self.init_annotations_dataset()
303
 
304
- def init_annotations_dataset(self):
305
- """Initialize or load existing annotations from HuggingFace"""
306
- try:
307
- if HF_TOKEN:
308
- api = HfApi(token=HF_TOKEN)
309
-
310
- # Check if dataset exists, if not create it
311
- if not repo_exists(OUTPUT_DATASET_NAME, repo_type="dataset", token=HF_TOKEN):
312
- print(f"Creating new dataset: {OUTPUT_DATASET_NAME}")
313
- create_repo(
314
- OUTPUT_DATASET_NAME,
315
- repo_type="dataset",
316
- private=OUTPUT_DATASET_PRIVATE,
317
- token=HF_TOKEN
318
- )
319
- # Create empty dataset structure
320
- self.push_empty_dataset()
321
- else:
322
- # Load existing annotations
323
- print(f"Loading existing annotations from {OUTPUT_DATASET_NAME}")
324
- self.load_existing_annotations()
325
- else:
326
- print("Warning: No HF_TOKEN found. Annotations will only be saved locally.")
327
- except Exception as e:
328
- print(f"Error initializing annotations dataset: {e}")
329
- print("Continuing with local-only mode")
330
-
331
- def push_empty_dataset(self):
332
- """Create and push empty dataset structure"""
333
- try:
334
- empty_data = {
335
- "user_id": [],
336
- "sample_id": [],
337
- "original_id": [],
338
- "model_a": [],
339
- "model_b": [],
340
- "choice": [],
341
- "prompt": [],
342
- "response_a": [],
343
- "response_b": [],
344
- "dataset": [],
345
- "timestamp": []
346
  }
347
-
348
- dataset = Dataset.from_dict(empty_data)
349
- dataset.push_to_hub(OUTPUT_DATASET_NAME, token=HF_TOKEN, private=OUTPUT_DATASET_PRIVATE)
350
- print(f"Created empty dataset at {OUTPUT_DATASET_NAME}")
351
- except Exception as e:
352
- print(f"Error creating empty dataset: {e}")
353
-
354
- def load_existing_annotations(self):
355
- """Load existing annotations from HuggingFace dataset"""
356
- try:
357
- dataset = load_dataset(OUTPUT_DATASET_NAME, split="train", token=HF_TOKEN)
358
-
359
- # Rebuild annotations dictionary from dataset
360
- for item in dataset:
361
- user_id = item["user_id"]
362
- if user_id not in self.annotations:
363
- self.annotations[user_id] = []
364
-
365
- # Add to user's annotations
366
- self.annotations[user_id].append({
367
- "user_id": user_id,
368
- "sample_id": item["sample_id"],
369
- "choice": item["choice"],
370
- "model_a": item.get("model_a", ""),
371
- "model_b": item.get("model_b", ""),
372
- "timestamp": item["timestamp"]
373
- })
374
-
375
- # Update user state
376
- if user_id not in self.user_states:
377
- self.user_states[user_id] = {
378
- "current_index": 0,
379
- "annotations": []
380
- }
381
- if item["sample_id"] not in self.user_states[user_id]["annotations"]:
382
- self.user_states[user_id]["annotations"].append(item["sample_id"])
383
-
384
- print(f"Loaded {len(dataset)} existing annotations")
385
-
386
- except Exception as e:
387
- print(f"Error loading existing annotations: {e}")
388
- print("Starting with empty annotations")
389
 
390
  def get_user_seed(self, user_id: str) -> int:
391
  """Generate consistent seed for user"""
@@ -396,26 +392,38 @@ class AnnotationManager:
396
  seed = self.get_user_seed(user_id)
397
  samples = DATASET_SAMPLES.copy()
398
  random.Random(seed).shuffle(samples)
 
 
 
 
399
  return samples
400
 
401
  def get_next_sample(self, user_id: str) -> Tuple[Dict, int, int]:
402
  """Get next unannotated sample for user"""
403
  if user_id not in self.user_states:
404
- self.user_states[user_id] = {
405
- "current_index": 0,
406
- "annotations": []
407
- }
 
 
 
 
 
 
 
 
408
 
409
  samples = self.get_user_samples(user_id)
410
  state = self.user_states[user_id]
411
 
412
- # Count already annotated
413
- annotated_count = len(state["annotations"])
414
 
415
  # Find next unannotated sample
416
- for i, sample in enumerate(samples):
417
  if not self.is_annotated(user_id, sample["id"]):
418
- return sample, annotated_count + 1, len(samples)
419
 
420
  # All samples annotated
421
  return None, len(samples), len(samples)
@@ -427,91 +435,49 @@ class AnnotationManager:
427
  return any(ann["sample_id"] == sample_id for ann in self.annotations[user_id])
428
 
429
  def save_annotation(self, user_id: str, sample_id: str, choice: str,
430
- sample_data: Dict = None):
431
- """Save user's annotation locally and to HuggingFace"""
432
- with self.lock:
433
- if user_id not in self.annotations:
434
- self.annotations[user_id] = []
435
-
436
- annotation = {
437
- "user_id": user_id,
438
- "sample_id": sample_id,
439
- "choice": choice,
440
- "timestamp": datetime.now().isoformat()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  }
442
-
443
- # Add sample data if provided
444
- if sample_data:
445
- annotation.update({
446
- "original_id": sample_data.get("original_id", ""),
447
- "model_a": sample_data.get("model_a", ""),
448
- "model_b": sample_data.get("model_b", ""),
449
- "prompt": sample_data.get("prompt", ""),
450
- "response_a": sample_data.get("response_a", ""),
451
- "response_b": sample_data.get("response_b", ""),
452
- "dataset": sample_data.get("dataset", "")
453
- })
454
-
455
- self.annotations[user_id].append(annotation)
456
-
457
- # Update user state
458
- if user_id in self.user_states:
459
- if sample_id not in self.user_states[user_id]["annotations"]:
460
- self.user_states[user_id]["annotations"].append(sample_id)
461
- self.user_states[user_id]["current_index"] += 1
462
-
463
- print(f"Saved annotation locally: {annotation['sample_id']} by {user_id}")
464
-
465
- # Save to HuggingFace asynchronously
466
- if HF_TOKEN:
467
- thread = threading.Thread(
468
- target=self.push_annotation_to_hub,
469
- args=(annotation,)
470
- )
471
- thread.daemon = True
472
- thread.start()
473
-
474
- def push_annotation_to_hub(self, annotation: Dict):
475
- """Push single annotation to HuggingFace dataset"""
476
- try:
477
- # Load current dataset
478
- dataset = load_dataset(OUTPUT_DATASET_NAME, split="train", token=HF_TOKEN)
479
-
480
- # Convert to dict
481
- data_dict = dataset.to_dict()
482
-
483
- # Ensure all keys exist
484
- required_keys = ["user_id", "sample_id", "original_id", "model_a",
485
- "model_b", "choice", "prompt", "response_a",
486
- "response_b", "dataset", "timestamp"]
487
-
488
- for key in required_keys:
489
- if key not in data_dict:
490
- data_dict[key] = []
491
- # Append new annotation data
492
- data_dict[key].append(annotation.get(key, ""))
493
-
494
- # Create new dataset and push
495
- updated_dataset = Dataset.from_dict(data_dict)
496
- updated_dataset.push_to_hub(
497
- OUTPUT_DATASET_NAME,
498
- token=HF_TOKEN,
499
- private=OUTPUT_DATASET_PRIVATE
500
- )
501
-
502
- print(f"Successfully pushed annotation to hub: {annotation['sample_id']}")
503
-
504
- except Exception as e:
505
- print(f"Error pushing annotation to hub: {e}")
506
- # Add to cache for batch upload later
507
- self.annotation_cache.append(annotation)
508
 
509
  def get_user_progress(self, user_id: str) -> Dict:
510
  """Get user's annotation progress"""
511
- if user_id not in self.user_states:
512
  return {"completed": 0, "total": len(DATASET_SAMPLES)}
513
 
514
- completed = len(self.user_states[user_id]["annotations"])
515
  return {"completed": completed, "total": len(DATASET_SAMPLES)}
516
 
517
 
@@ -540,13 +506,16 @@ def login(user_id: str) -> Tuple:
540
  gr.update(visible=True), # login_interface
541
  gr.update(visible=False), # annotation_interface
542
  user_id, # user_state
543
- gr.update(value=f"All samples completed for user: {user_id}"), # login_status
544
  gr.update(), # prompt
545
  gr.update(), # response_a
546
  gr.update(), # response_b
547
  gr.update() # progress
548
  )
549
 
 
 
 
550
  return (
551
  gr.update(visible=False), # login_interface
552
  gr.update(visible=True), # annotation_interface
@@ -555,8 +524,7 @@ def login(user_id: str) -> Tuple:
555
  gr.update(value=sample["prompt"]), # prompt
556
  gr.update(value=sample["response_a"]), # response_a
557
  gr.update(value=sample["response_b"]), # response_b
558
- gr.update(value=f"Progress: {current}/{total} | Comparing: {sample.get('model_a', 'A')} vs {sample.get('model_b', 'B')}") # progress
559
- # gr.update(value=f"Progress: {current}/{total}") # progress
560
  )
561
 
562
  def annotate(choice: str, user_id: str) -> Tuple:
@@ -579,13 +547,15 @@ def annotate(choice: str, user_id: str) -> Tuple:
579
  "b_better": "B is more fluent",
580
  "equal": "Equally fluent"
581
  }
582
-
583
- # Save with full sample data for HuggingFace dataset
584
  manager.save_annotation(
585
- user_id,
586
- sample["id"],
587
- choice_map[choice],
588
- sample_data=sample # Pass the full sample data
 
 
 
589
  )
590
 
591
  # Get next sample
@@ -607,8 +577,7 @@ def annotate(choice: str, user_id: str) -> Tuple:
607
  gr.update(value=next_sample["prompt"]), # prompt
608
  gr.update(value=next_sample["response_a"]), # response_a
609
  gr.update(value=next_sample["response_b"]), # response_b
610
- gr.update(value=f"Progress: {current}/{total} | Comparing: {sample.get('model_a', 'A')} vs {sample.get('model_b', 'B')}"),
611
- # gr.update(value=f"Progress: {current}/{total}{model_info}"), # progress
612
  gr.update(value="Annotation saved!", visible=True) # status
613
  )
614
 
 
9
  import hashlib
10
  import itertools
11
  from datasets import load_dataset, Dataset, DatasetDict
12
+ from huggingface_hub import HfApi, create_repo, repo_exists, Repository
13
+ import shutil
14
  import threading
15
 
16
  from collections.abc import Iterable
 
219
  """
220
 
221
  # Configuration for the output dataset
222
+ ANNOTATIONS_REPO = "ltg/fluency-annotations" # Change to your repo name
223
+ ANNOTATIONS_FILE = "train.jsonl"
224
 
225
  HF_TOKEN = os.environ.get("HF_TOKEN")
226
 
 
230
  # Create all pairwise comparisons
231
  MODEL_PAIRS = list(itertools.combinations(MODEL_NAMES, 2))
232
 
233
+ # Initialize repository
234
+ def init_repository():
235
+ """Initialize or clone the repository"""
236
+ try:
237
+ repo = Repository(
238
+ local_dir=DATA_DIR,
239
+ clone_from=ANNOTATIONS_REPO,
240
+ use_auth_token=HF_TOKEN,
241
+ repo_type="dataset"
242
+ )
243
+ repo.git_pull()
244
+ return repo
245
+ except Exception as e:
246
+ print(f"Error initializing repository: {e}")
247
+ # Create local directory if repo doesn't exist
248
+ os.makedirs(DATA_DIR, exist_ok=True)
249
+ return None
250
+
251
+ # Initialize on startup
252
+ annotation_repo = init_repository()
253
+
254
+ def load_existing_annotations():
255
+ """Load existing annotations from the jsonl file"""
256
+ annotations = {}
257
+
258
+ if os.path.exists(ANNOTATIONS_FILE):
259
+ try:
260
+ with open(ANNOTATIONS_FILE, "r") as f:
261
+ for line in f:
262
+ if line.strip():
263
+ ann = json.loads(line)
264
+ user_id = ann.get("user_id")
265
+ if user_id:
266
+ if user_id not in annotations:
267
+ annotations[user_id] = []
268
+ annotations[user_id].append(ann)
269
+ print(f"Loaded {sum(len(v) for v in annotations.values())} existing annotations")
270
+ except Exception as e:
271
+ print(f"Error loading annotations: {e}")
272
+
273
+ return annotations
274
+
275
+ def save_annotation_to_file(annotation_data):
276
+ """Save a single annotation to the jsonl file and push to hub"""
277
+ global annotation_repo
278
+
279
+ try:
280
+ # Pull latest changes
281
+ if annotation_repo:
282
+ annotation_repo.git_pull()
283
+
284
+ # Append to jsonl file
285
+ with open(ANNOTATIONS_FILE, "a") as f:
286
+ line = json.dumps(annotation_data, ensure_ascii=False)
287
+ f.write(f"{line}\n")
288
+
289
+ # Push to hub asynchronously
290
+ if annotation_repo:
291
+ annotation_repo.push_to_hub(blocking=False, commit_message="Add annotation")
292
+
293
+ except Exception as e:
294
+ print(f"Error saving annotation: {e}")
295
+ # Try to reinitialize repository
296
+ try:
297
+ shutil.rmtree(DATA_DIR)
298
+ annotation_repo = init_repository()
299
+
300
+ # Retry saving
301
+ with open(ANNOTATIONS_FILE, "a") as f:
302
+ line = json.dumps(annotation_data, ensure_ascii=False)
303
+ f.write(f"{line}\n")
304
+
305
+ if annotation_repo:
306
+ annotation_repo.push_to_hub(blocking=False, commit_message="Add annotation")
307
+ except Exception as e2:
308
+ print(f"Failed to save annotation after retry: {e2}")
309
+
310
  def load_dataset_samples():
311
  """Load and prepare dataset samples with pairwise comparisons"""
312
  try:
 
371
 
372
  class AnnotationManager:
373
  def __init__(self):
374
+ # Load existing annotations from file
375
+ self.annotations = load_existing_annotations()
376
+ self.user_states = {}
 
 
 
 
377
 
378
+ # Rebuild user states from loaded annotations
379
+ for user_id, user_annotations in self.annotations.items():
380
+ annotated_ids = [ann["sample_id"] for ann in user_annotations]
381
+ self.user_states[user_id] = {
382
+ "current_index": 0,
383
+ "annotations": annotated_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
  def get_user_seed(self, user_id: str) -> int:
387
  """Generate consistent seed for user"""
 
392
  seed = self.get_user_seed(user_id)
393
  samples = DATASET_SAMPLES.copy()
394
  random.Random(seed).shuffle(samples)
395
+ samples = [
396
+ sample if random.Random(seed + i).randint(0, 1) == 0 else swap_sample(sample)
397
+ for i, sample in enumerate(samples)
398
+ ]
399
  return samples
400
 
401
  def get_next_sample(self, user_id: str) -> Tuple[Dict, int, int]:
402
  """Get next unannotated sample for user"""
403
  if user_id not in self.user_states:
404
+ # Check if user has existing annotations
405
+ if user_id in self.annotations:
406
+ annotated_ids = [ann["sample_id"] for ann in self.annotations[user_id]]
407
+ self.user_states[user_id] = {
408
+ "current_index": 0,
409
+ "annotations": annotated_ids
410
+ }
411
+ else:
412
+ self.user_states[user_id] = {
413
+ "current_index": 0,
414
+ "annotations": []
415
+ }
416
 
417
  samples = self.get_user_samples(user_id)
418
  state = self.user_states[user_id]
419
 
420
+ # Count total annotations for this user
421
+ total_annotated = len(state["annotations"])
422
 
423
  # Find next unannotated sample
424
+ for idx, sample in enumerate(samples):
425
  if not self.is_annotated(user_id, sample["id"]):
426
+ return sample, total_annotated + 1, len(samples)
427
 
428
  # All samples annotated
429
  return None, len(samples), len(samples)
 
435
  return any(ann["sample_id"] == sample_id for ann in self.annotations[user_id])
436
 
437
  def save_annotation(self, user_id: str, sample_id: str, choice: str,
438
+ model_a: str = None, model_b: str = None,
439
+ original_id: str = None, dataset_name: str = None):
440
+ """Save user's annotation and persist to file"""
441
+ if user_id not in self.annotations:
442
+ self.annotations[user_id] = []
443
+
444
+ annotation = {
445
+ "user_id": user_id,
446
+ "sample_id": sample_id,
447
+ "original_sample_id": original_id,
448
+ "dataset": dataset_name,
449
+ "model_a": model_a,
450
+ "model_b": model_b,
451
+ "choice": choice,
452
+ "timestamp": datetime.now().isoformat()
453
+ }
454
+
455
+ # Save to memory
456
+ self.annotations[user_id].append(annotation)
457
+
458
+ # Update user state
459
+ if user_id in self.user_states:
460
+ self.user_states[user_id]["annotations"].append(sample_id)
461
+ else:
462
+ self.user_states[user_id] = {
463
+ "current_index": 0,
464
+ "annotations": [sample_id]
465
  }
466
+
467
+ # Save to file asynchronously
468
+ threading.Thread(
469
+ target=save_annotation_to_file,
470
+ args=(annotation,)
471
+ ).start()
472
+
473
+ print(f"Saved annotation: {annotation}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
 
475
  def get_user_progress(self, user_id: str) -> Dict:
476
  """Get user's annotation progress"""
477
+ if user_id not in self.annotations:
478
  return {"completed": 0, "total": len(DATASET_SAMPLES)}
479
 
480
+ completed = len(self.annotations[user_id])
481
  return {"completed": completed, "total": len(DATASET_SAMPLES)}
482
 
483
 
 
506
  gr.update(visible=True), # login_interface
507
  gr.update(visible=False), # annotation_interface
508
  user_id, # user_state
509
+ gr.update(value=f"All {total} samples completed for user: {user_id}! 🎉"), # login_status
510
  gr.update(), # prompt
511
  gr.update(), # response_a
512
  gr.update(), # response_b
513
  gr.update() # progress
514
  )
515
 
516
+ # Show which models are being compared
517
+ model_info = f" | Comparing: {sample.get('model_a', 'A')} vs {sample.get('model_b', 'B')}"
518
+
519
  return (
520
  gr.update(visible=False), # login_interface
521
  gr.update(visible=True), # annotation_interface
 
524
  gr.update(value=sample["prompt"]), # prompt
525
  gr.update(value=sample["response_a"]), # response_a
526
  gr.update(value=sample["response_b"]), # response_b
527
+ gr.update(value=f"Progress: {current}/{total}{model_info}") # progress
 
528
  )
529
 
530
  def annotate(choice: str, user_id: str) -> Tuple:
 
547
  "b_better": "B is more fluent",
548
  "equal": "Equally fluent"
549
  }
550
+ # Save with all metadata
 
551
  manager.save_annotation(
552
+ user_id=user_id,
553
+ sample_id=sample["id"],
554
+ choice=choice_map[choice],
555
+ model_a=sample.get("model_a"),
556
+ model_b=sample.get("model_b"),
557
+ original_id=sample.get("original_id"),
558
+ dataset_name=sample.get("dataset")
559
  )
560
 
561
  # Get next sample
 
577
  gr.update(value=next_sample["prompt"]), # prompt
578
  gr.update(value=next_sample["response_a"]), # response_a
579
  gr.update(value=next_sample["response_b"]), # response_b
580
+ gr.update(value=f"Progress: {current}/{total}{model_info}"), # progress
 
581
  gr.update(value="Annotation saved!", visible=True) # status
582
  )
583