tefoteknik commited on
Commit
c31993e
·
verified ·
1 Parent(s): 74e89c5

Update AGIFORMER with Turkish benchmark

Browse files
Files changed (1) hide show
  1. generate.py +28 -12
generate.py CHANGED
@@ -22,7 +22,9 @@ def generate_text(model_path, prompt_text, max_new_tokens=200, temperature=0.8):
22
  model.load_state_dict(state_dict)
23
  model.eval()
24
 
25
- input_bytes = [ord(c) for c in prompt_text]
 
 
26
  pad_len = (PATCH_SIZE - (len(input_bytes) % PATCH_SIZE)) % PATCH_SIZE
27
  if pad_len > 0:
28
  input_bytes.extend([32] * pad_len)
@@ -44,18 +46,32 @@ def generate_text(model_path, prompt_text, max_new_tokens=200, temperature=0.8):
44
  last_patch = pred_patches[0, -1, :].cpu().tolist()
45
  generated.extend(last_patch)
46
 
47
- decoded_str = ""
48
- for b in last_patch:
49
- if 32 <= b <= 126 or b == 10 or b == 9:
50
- decoded_str += chr(b)
51
- else:
52
- # Simple representation for non-printables
53
- pass
54
-
55
- print(decoded_str, end='', flush=True)
56
 
57
  print("\n" + "-" * 50)
 
 
 
 
 
 
58
 
59
  if __name__ == "__main__":
60
- # Test with a generic English prompt to see if it generalizes beyond XML
61
- generate_text("best_model.pth", "The history of ", temperature=0.7)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  model.load_state_dict(state_dict)
23
  model.eval()
24
 
25
+ # Encode prompt to UTF-8 bytes
26
+ input_bytes = list(prompt_text.encode('utf-8'))
27
+
28
  pad_len = (PATCH_SIZE - (len(input_bytes) % PATCH_SIZE)) % PATCH_SIZE
29
  if pad_len > 0:
30
  input_bytes.extend([32] * pad_len)
 
46
  last_patch = pred_patches[0, -1, :].cpu().tolist()
47
  generated.extend(last_patch)
48
 
49
+ # Real-time decoding for display is tricky with multi-byte chars
50
+ # We'll just collect and decode at the end or try best effort
51
+ pass
 
 
 
 
 
 
52
 
53
  print("\n" + "-" * 50)
54
+ try:
55
+ full_text = bytes(generated).decode('utf-8', errors='replace')
56
+ # Print only the new part
57
+ print(full_text[len(prompt_text):])
58
+ except:
59
+ print("\n[Decoding Error]")
60
 
61
  if __name__ == "__main__":
62
+ import argparse
63
+
64
+ parser = argparse.ArgumentParser(description='Generate text with AGIFORMER')
65
+ parser.add_argument('--prompt', type=str, default="The history of ", help='Text prompt to start generation')
66
+ parser.add_argument('--temp', type=float, default=0.7, help='Sampling temperature')
67
+ parser.add_argument('--model', type=str, default="best_model.pth", help='Path to model checkpoint')
68
+
69
+ args = parser.parse_args()
70
+
71
+ # Check if user meant to use the Turkish model but it's named differently
72
+ model_path = args.model
73
+ if not os.path.exists(model_path) and os.path.exists("best_model_turkish.pth"):
74
+ print(f"Note: '{model_path}' not found, using 'best_model_turkish.pth' instead.")
75
+ model_path = "best_model_turkish.pth"
76
+
77
+ generate_text(model_path, args.prompt, temperature=args.temp)