Update AGIFORMER with Turkish benchmark
Browse files- generate.py +28 -12
generate.py
CHANGED
|
@@ -22,7 +22,9 @@ def generate_text(model_path, prompt_text, max_new_tokens=200, temperature=0.8):
|
|
| 22 |
model.load_state_dict(state_dict)
|
| 23 |
model.eval()
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
pad_len = (PATCH_SIZE - (len(input_bytes) % PATCH_SIZE)) % PATCH_SIZE
|
| 27 |
if pad_len > 0:
|
| 28 |
input_bytes.extend([32] * pad_len)
|
|
@@ -44,18 +46,32 @@ def generate_text(model_path, prompt_text, max_new_tokens=200, temperature=0.8):
|
|
| 44 |
last_patch = pred_patches[0, -1, :].cpu().tolist()
|
| 45 |
generated.extend(last_patch)
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
decoded_str += chr(b)
|
| 51 |
-
else:
|
| 52 |
-
# Simple representation for non-printables
|
| 53 |
-
pass
|
| 54 |
-
|
| 55 |
-
print(decoded_str, end='', flush=True)
|
| 56 |
|
| 57 |
print("\n" + "-" * 50)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
if __name__ == "__main__":
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
model.load_state_dict(state_dict)
|
| 23 |
model.eval()
|
| 24 |
|
| 25 |
+
# Encode prompt to UTF-8 bytes
|
| 26 |
+
input_bytes = list(prompt_text.encode('utf-8'))
|
| 27 |
+
|
| 28 |
pad_len = (PATCH_SIZE - (len(input_bytes) % PATCH_SIZE)) % PATCH_SIZE
|
| 29 |
if pad_len > 0:
|
| 30 |
input_bytes.extend([32] * pad_len)
|
|
|
|
| 46 |
last_patch = pred_patches[0, -1, :].cpu().tolist()
|
| 47 |
generated.extend(last_patch)
|
| 48 |
|
| 49 |
+
# Real-time decoding for display is tricky with multi-byte chars
|
| 50 |
+
# so nothing is printed here; the collected bytes are decoded once at the end.
|
| 51 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
print("\n" + "-" * 50)
|
| 54 |
+
try:
|
| 55 |
+
full_text = bytes(generated).decode('utf-8', errors='replace')
|
| 56 |
+
# Print only the new part
|
| 57 |
+
print(full_text[len(prompt_text):])
|
| 58 |
+
except:
|
| 59 |
+
print("\n[Decoding Error]")
|
| 60 |
|
| 61 |
if __name__ == "__main__":
    # Command-line entry point: parse options, pick a checkpoint path, and
    # hand everything to generate_text().
    #
    # Both modules are imported locally so this block does not rely on a
    # file-level import that may or may not exist.
    import argparse
    import os

    # Checkpoint tried when the path given via --model does not exist.
    FALLBACK_MODEL = "best_model_turkish.pth"

    parser = argparse.ArgumentParser(description='Generate text with AGIFORMER')
    parser.add_argument('--prompt', type=str, default="The history of ", help='Text prompt to start generation')
    parser.add_argument('--temp', type=float, default=0.7, help='Sampling temperature')
    parser.add_argument('--model', type=str, default="best_model.pth", help='Path to model checkpoint')
    # Default mirrors generate_text's own max_new_tokens default, so running
    # without this flag behaves exactly as before.
    parser.add_argument('--max-new-tokens', type=int, default=200,
                        help='Value forwarded to generate_text as max_new_tokens')

    args = parser.parse_args()

    # Check if the user meant to use the Turkish model but it's named differently.
    model_path = args.model
    if not os.path.exists(model_path) and os.path.exists(FALLBACK_MODEL):
        print(f"Note: '{model_path}' not found, using '{FALLBACK_MODEL}' instead.")
        model_path = FALLBACK_MODEL

    # If neither path exists, model_path is passed through unchanged and any
    # failure is left to generate_text.
    generate_text(
        model_path,
        args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temp,
    )
|