Thishyaketh committed
Commit 8e36426 · verified · 1 Parent(s): 8365e4e

Upload 4 files

converter.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import argparse
+ import torch
+ from safetensors.torch import save_file
+
+ def extract_state_dict(checkpoint):
+     """
+     Extracts the tensor dictionary from common .pth formats.
+     """
+     if isinstance(checkpoint, dict):
+         for key in ["state_dict", "model", "model_state_dict", "module"]:
+             if key in checkpoint and isinstance(checkpoint[key], dict):
+                 return checkpoint[key]
+     return checkpoint
+
+ def convert_pth_to_safetensors(input_path, output_path=None):
+     print(f"🔍 Loading checkpoint from: {input_path}")
+
+     try:
+         checkpoint = torch.load(input_path, map_location="cpu", weights_only=True)
+     except Exception as e:
+         print(f"❌ Error loading .pth file: {e}")
+         return
+
+     state_dict = extract_state_dict(checkpoint)
+
+     if not isinstance(state_dict, dict):
+         print("❌ Invalid checkpoint: not a dictionary.")
+         return
+
+     # Keep only tensor values; save_file requires contiguous tensors.
+     tensor_dict = {k: v.contiguous() for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
+
+     if not tensor_dict:
+         print("❌ No tensor values found to convert.")
+         return
+
+     # Add a "model." prefix only where it is missing, so mixed dicts are not double-prefixed.
+     tensor_dict = {(k if k.startswith("model.") else f"model.{k}"): v for k, v in tensor_dict.items()}
+
+     if output_path is None:
+         output_path = os.path.splitext(input_path)[0] + ".safetensors"
+
+     try:
+         print(f"💾 Saving to: {output_path}")
+         save_file(tensor_dict, output_path)
+         print("✅ Conversion to .safetensors successful!")
+     except Exception as e:
+         print(f"❌ Saving failed: {e}")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Convert .pth to .safetensors")
+     parser.add_argument("input", help="Path to input .pth file")
+     parser.add_argument("--output", help="Path to output .safetensors file (optional)")
+
+     args = parser.parse_args()
+     convert_pth_to_safetensors(args.input, args.output)
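
Usage (a minimal sketch; it assumes converter.py is importable from the working directory and uses the .pth checkpoint added in this same commit):

    from converter import convert_pth_to_safetensors

    # Output path is derived automatically:
    # enhanced_transformer_model_500M_final.safetensors
    convert_pth_to_safetensors("enhanced_transformer_model_500M_final.pth")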
enhanced_transformer_model_500M_final.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ecf29d34e90a5bc9d26cd19401753a24b8dda6d9afc726681363023c26b6acb2
+ size 112823067
enhanced_transformer_model_500M_final.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f80dcb5311ccae55f66030c0b3285c37bae862d3cd48657a8dc33f04e6e3435c
+ size 112807876
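
Both weight files are stored as Git LFS pointers (version, oid, size); a plain clone fetches only these stubs, and git lfs pull retrieves the actual binaries. One way to sanity-check the converted file without instantiating a model (a sketch, assuming the .safetensors file has been pulled locally):

    from safetensors import safe_open

    # Stream tensor names and shapes without loading the full weights into memory.
    with safe_open("enhanced_transformer_model_500M_final.safetensors", framework="pt", device="cpu") as f:
        for name in f.keys():
            print(name, f.get_slice(name).get_shape())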
textgen.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ def load_model(model_path):
+     # Load the model (from_pretrained expects a model directory with config.json and weights, not a raw .pth checkpoint)
+     model = AutoModelForCausalLM.from_pretrained(model_path)
+     # Load the tokenizer (assuming it's saved alongside the model)
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+     return model, tokenizer
+
+ def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
+     # Encode the prompt
+     input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+     # Generate text (do_sample=True is required for temperature to take effect)
+     with torch.no_grad():
+         output = model.generate(
+             input_ids,
+             max_length=max_length,
+             do_sample=True, temperature=temperature,
+             num_return_sequences=1,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     # Decode and return the generated text
+     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+     return generated_text
+
+ if __name__ == "__main__":
+     model_path = "enhanced_transformer_model_500M_final.pth"
+     model, tokenizer = load_model(model_path)
+
+     prompt = "Once upon a time"
+     generated_text = generate_text(model, tokenizer, prompt)
+     print(generated_text)
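
Note: as committed, model_path points at the raw .pth file, which from_pretrained cannot load; it needs a model directory containing config.json plus weights. Until such a directory exists, the converted weights can still be inspected tensor-by-tensor (a sketch using only the .safetensors file from this commit):

    from safetensors.torch import load_file

    # Loads every tensor into a plain dict keyed by the (possibly "model."-prefixed) names.
    state_dict = load_file("enhanced_transformer_model_500M_final.safetensors")
    print(f"{len(state_dict)} tensors; first key: {next(iter(state_dict))}")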