nexusbert committed
Commit c4efecf · 1 Parent(s): 2d6c16e

push max token

Files changed (1)
  1. app.py +24 -10
app.py CHANGED
@@ -391,21 +391,27 @@ Produce ONLY valid JSON with these exact fields:
     prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")
     prompt_token_count = prompt_tokens.shape[1]

+    max_context = 4096
     max_input_tokens = 3800
-    max_output_tokens = 4096 - max_input_tokens

     if prompt_token_count > max_input_tokens:
         logger.warning(f"Prompt is {prompt_token_count} tokens, truncating to {max_input_tokens}")
         prompt_tokens = prompt_tokens[:, :max_input_tokens]
         prompt = tokenizer.decode(prompt_tokens[0], skip_special_tokens=True)
+        prompt_token_count = max_input_tokens

-    logger.info(f"Input tokens: ~{prompt_token_count}, Max output tokens: {max_output_tokens}")
+    max_output_tokens = max_context - prompt_token_count - 50

-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_input_tokens).to(model.device)
+    logger.info(f"Input tokens: {prompt_token_count}, Available output tokens: {max_output_tokens}")
+
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=prompt_token_count).to(model.device)
+
+    output_limit = min(1500, max_output_tokens)
+    logger.info(f"Setting max_new_tokens to {output_limit}")

     outputs = model.generate(
         **inputs,
-        max_new_tokens=min(1500, max_output_tokens),
+        max_new_tokens=output_limit,
         temperature=0.3,
         do_sample=True,
         top_p=0.95,
@@ -462,19 +468,23 @@ Produce a FINAL comprehensive review with the same JSON structure as before, con
     prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")
     prompt_token_count = prompt_tokens.shape[1]

+    max_context = 4096
     max_input_tokens = 3800
-    max_output_tokens = 4096 - max_input_tokens

     if prompt_token_count > max_input_tokens:
         logger.warning(f"Combine prompt is {prompt_token_count} tokens, truncating to {max_input_tokens}")
         prompt_tokens = prompt_tokens[:, :max_input_tokens]
         prompt = tokenizer.decode(prompt_tokens[0], skip_special_tokens=True)
+        prompt_token_count = max_input_tokens

-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_input_tokens).to(model.device)
+    max_output_tokens = max_context - prompt_token_count - 50

+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=prompt_token_count).to(model.device)
+
+    output_limit = min(1500, max_output_tokens)
     outputs = model.generate(
         **inputs,
-        max_new_tokens=min(1500, max_output_tokens),
+        max_new_tokens=output_limit,
         temperature=0.3,
         do_sample=True,
         top_p=0.95,
@@ -583,19 +593,23 @@ Return ONLY valid JSON:
     prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")
     prompt_token_count = prompt_tokens.shape[1]

+    max_context = 4096
     max_input_tokens = 3800
-    max_output_tokens = 4096 - max_input_tokens

     if prompt_token_count > max_input_tokens:
         logger.warning(f"Improvement prompt is {prompt_token_count} tokens, truncating to {max_input_tokens}")
         prompt_tokens = prompt_tokens[:, :max_input_tokens]
         prompt = tokenizer.decode(prompt_tokens[0], skip_special_tokens=True)
+        prompt_token_count = max_input_tokens
+
+    max_output_tokens = max_context - prompt_token_count - 50

-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_input_tokens).to(model.device)
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=prompt_token_count).to(model.device)

+    output_limit = min(1000, max_output_tokens)
     outputs = model.generate(
         **inputs,
-        max_new_tokens=min(1000, max_output_tokens),
+        max_new_tokens=output_limit,
         temperature=0.4,
         do_sample=True,
         top_p=0.95,
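
For context on what the three (otherwise identical) hunks do: the old code reserved a fixed output budget of 4096 - 3800 = 296 tokens regardless of actual prompt length, so generate() never received more than min(1500, 296) = 296 new tokens. The new code measures the real prompt, pins the count to max_input_tokens after truncation, subtracts a 50-token safety margin from the 4096-token context, and sizes max_new_tokens from the remainder, so a 500-token prompt now gets the full 1500-token cap (4096 - 500 - 50 = 3546 available). Below is a minimal standalone sketch of that pattern; the helper name generate_within_budget, the cap parameter, and the "gpt2" placeholder checkpoint are illustrative assumptions, not identifiers from app.py.

# A minimal sketch of the token-budgeting pattern this commit applies.
# Assumptions (not from app.py): the helper name, the `cap` parameter,
# and the "gpt2" placeholder checkpoint (the diff implies the real
# model has a 4096-token context window).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

def generate_within_budget(prompt: str, cap: int = 1500) -> str:
    max_context = 4096        # assumed context window, as in the diff
    max_input_tokens = 3800   # hard cap on prompt tokens
    safety_margin = 50        # headroom for special tokens

    prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")
    prompt_token_count = prompt_tokens.shape[1]

    # Truncate over-long prompts and pin the count to the truncated
    # length so the output budget reflects what is actually sent.
    if prompt_token_count > max_input_tokens:
        prompt_tokens = prompt_tokens[:, :max_input_tokens]
        prompt = tokenizer.decode(prompt_tokens[0], skip_special_tokens=True)
        prompt_token_count = max_input_tokens

    # Whatever context remains after prompt and margin is the output budget.
    max_output_tokens = max_context - prompt_token_count - safety_margin
    output_limit = min(cap, max_output_tokens)

    inputs = tokenizer(prompt, return_tensors="pt",
                       truncation=True, max_length=prompt_token_count)
    outputs = model.generate(**inputs, max_new_tokens=output_limit,
                             temperature=0.3, do_sample=True, top_p=0.95)
    # Return only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:],
                            skip_special_tokens=True)

Since the three call sites differ only in their log messages, the per-call cap (1500 twice, 1000 once), and the sampling temperature (0.3 vs. 0.4), one helper along these lines could serve all of them.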