amanuelbyte committed
Commit 7b94069 · verified · 1 Parent(s): 10c3b7d

Add README model card

Files changed (1)
  1. README.md +54 -1
README.md CHANGED
@@ -1,3 +1,4 @@
+
  ---
  language: am
  license: apache-2.0
@@ -6,4 +7,56 @@ tags:
  - text-generation
  - custom-model
  - hrm
- ---
+ ---
+
+ # HRM-Text1 Amharic Model
+
+ This is a custom text generation model based on the Hierarchical Recurrent Memory (HRM) architecture. It was trained from scratch on the `amanuelbyte/Amharic_dataset` dataset.
+
+ **This is a custom model and requires `trust_remote_code=True` to load.**
+
+ ## How to Use
+
+ Because this is a custom architecture, you need to load the model by importing the `HRMText1` class from the `hrm_model.py` file in this repository.
+
+ ```python
+ import torch
+ import json
+ from transformers import T5Tokenizer
+ from huggingface_hub import hf_hub_download
+
+ # hrm_model.py (which defines HRMText1) must be available locally for this import;
+ # if needed, it can be downloaded from the repo first, e.g.:
+ #   hf_hub_download(repo_id="amanuelbyte/HRM-amharic", filename="hrm_model.py", local_dir=".")
+ from hrm_model import HRMText1  # Import the custom class
+
+ # Replace with your repo ID
+ repo_id = "amanuelbyte/HRM-amharic"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # 1. Load the tokenizer
+ tokenizer = T5Tokenizer.from_pretrained(repo_id)
+
+ # 2. Load the model's configuration
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
+ with open(config_path, 'r') as f:
+     config = json.load(f)
+
+ # 3. Instantiate the model with the config.
+ # trust_remote_code=True is only relevant when loading through from_pretrained / the
+ # Auto classes; it is not needed here because the class is imported manually.
+ model = HRMText1(config)
+
+ # 4. Load the model weights
+ weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
+ state_dict = torch.load(weights_path, map_location=device)
+ model.load_state_dict(state_dict)
+ model.to(device)
+ model.eval()
+
+ print("Model loaded successfully!")
+
+ # Now you can use the model for generation...
+ prompt = "የኢትዮጵያ ዋና ከተማ"
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+ with torch.inference_mode():
+     output_ids = model.generate(input_ids, max_new_tokens=50)  # Assuming a generate method exists
+
+ # Assuming generate follows the usual transformers convention and returns a batch of
+ # sequences, decode the first one.
+ print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
+ ```
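
The example above instantiates the class and loads the weights manually. If the repository's `config.json` also registers `HRMText1` under an `auto_map` entry (an assumption; the card does not say so), the same model could be loaded through the `transformers` Auto classes, which is where `trust_remote_code=True` actually comes into play. A minimal sketch of that path, with the exact Auto class depending on how the mapping is registered:

```python
# Hypothetical alternative, assuming config.json maps AutoModelForCausalLM to
# hrm_model.HRMText1; if it does not, use the manual loading shown above.
import torch
from transformers import AutoModelForCausalLM, T5Tokenizer

repo_id = "amanuelbyte/HRM-amharic"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = T5Tokenizer.from_pretrained(repo_id)
# trust_remote_code=True lets transformers import and run hrm_model.py from the repo
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).to(device)
model.eval()

prompt = "የኢትዮጵያ ዋና ከተማ"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
with torch.inference_mode():
    output_ids = model.generate(input_ids, max_new_tokens=50)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

If `from_pretrained` fails with an unrecognized-architecture error, the `auto_map` entry is not present and the manual route shown above is the one to use.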