---
base_model:
- Qwen/Qwen2.5-0.5B-Instruct
library_name: transformers
---

# ALM-Qwen Model: ALM-Qwen-0.5B-testing

This repository contains an Attention-Linked Memory augmented Qwen model (ALM-Qwen).

## Model Components

*   **AttentionLinkedMemory (ALM)**: A custom PyTorch module for two-level attention-based retrieval from structured memory. (See `ALM.py`; an illustrative sketch follows this list.)
*   **QwenGenerator**: Wraps a Hugging Face Qwen model (e.g., Qwen2.5-0.5B-Instruct or Qwen2.5-7B-Instruct) for text generation.
*   **ALMQwenModel_HF**: The main class orchestrating the ALM retrieval and Qwen generation. (See `alm_qwen.py`)
*   **Saved Weights & Config**:
    *   `alm_layer_state_dict.pth`: Trained weights for the ALM layer.
    *   `alm_qwen_hf_config.json`: Configuration for the `ALMQwenModel_HF`, including ALM parameters and paths to the Qwen components.
    *   `qwen_generator/`: Contains the saved Hugging Face Qwen model and tokenizer.
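
The two-level retrieval idea behind ALM can be pictured with a small, self-contained sketch. This is **not** the implementation in `ALM.py`; the class name, projections, and scoring scheme below are illustrative assumptions, kept consistent with the query/memory shapes used in the usage example further down.

```python
# Illustrative sketch only, NOT the ALM.py implementation: score items inside each
# memory bucket first (level 1), then score the bucket summaries (level 2).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoLevelAttentionRetrieval(nn.Module):
    def __init__(self, query_dim: int, memory_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.query_proj = nn.Linear(query_dim, hidden_dim)
        self.memory_proj = nn.Linear(memory_dim, hidden_dim)

    def forward(self, query, memory, mask):
        # query:  (batch, query_dim)
        # memory: (batch, buckets, items, memory_dim)
        # mask:   (batch, buckets, items), True for valid items
        q = self.query_proj(query)                                  # (B, H)
        m = self.memory_proj(memory)                                # (B, K, I, H)

        # Level 1: attention over items within each bucket
        item_scores = torch.einsum("bh,bkih->bki", q, m)
        item_scores = item_scores.masked_fill(~mask, float("-inf"))
        item_attn = F.softmax(item_scores, dim=-1)                  # (B, K, I)
        bucket_repr = torch.einsum("bki,bkih->bkh", item_attn, m)   # (B, K, H)

        # Level 2: attention over bucket summaries
        bucket_scores = torch.einsum("bh,bkh->bk", q, bucket_repr)
        bucket_attn = F.softmax(bucket_scores, dim=-1)              # (B, K)
        retrieved = torch.einsum("bk,bkh->bh", bucket_attn, bucket_repr)
        return retrieved, item_attn, bucket_attn
```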

## How to Use

1.  **Prerequisites**:
    ```bash
    pip install torch transformers huggingface_hub sentencepiece accelerate
    # Add other dependencies if any, e.g., bitsandbytes for quantization
    ```

2.  **Clone the repository (or download files manually; a `snapshot_download` alternative is sketched after these steps)**:
    ```bash
    git lfs install  # recommended: the Qwen weights under qwen_generator/ are typically stored with Git LFS
    git clone https://huggingface.co/moelanoby/ALM-Qwen-0.5B-testing
    cd ALM-Qwen-0.5B-testing
    ```

3.  **Load the model in Python**:
    ```python
    from alm_qwen import ALMQwenModel_HF  # make sure alm_qwen.py and ALM.py are importable (on your PYTHONPATH)
    import torch

    # Desired device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Path to the directory where you cloned/downloaded the model
    model_directory = "." # Or the specific path if you are running from outside the cloned repo

    # Load the model
    loaded_model = ALMQwenModel_HF.load_model(model_directory, device=device)
    print("ALM-Qwen model loaded successfully!")

    # --- Prepare Dummy Input Data (similar to the example in alm_qwen_hf.py) ---
    # batch_size = 1
    # alm_query_dim = loaded_model.alm_config['query_dim']
    # alm_memory_dim = loaded_model.alm_config['memory_dim']
    # num_kb_buckets = 3 # Example
    # max_kb_items_per_bucket = 5 # Example

    # query_texts = ["What is the capital of France?"]
    # query_embeddings_for_alm = torch.randn(batch_size, alm_query_dim)
    # memory_item_embeddings = torch.randn(batch_size, num_kb_buckets, max_kb_items_per_bucket, alm_memory_dim)
    # memory_text_items = [[["Paris is the capital of France." for _ in range(max_kb_items_per_bucket)] for _ in range(num_kb_buckets)] for _ in range(batch_size)]
    # memory_mask = torch.ones(batch_size, num_kb_buckets, max_kb_items_per_bucket, dtype=torch.bool)
    # memory_mask[:, :, -1] = False # Example mask

    # # Run inference
    # generated_answers, _, _ = loaded_model(
    #     query_texts,
    #     query_embeddings_for_alm,
    #     memory_item_embeddings,
    #     memory_text_items,
    #     memory_mask
    # )
    # print(f"Query: {query_texts[0]}")
    # print(f"Answer: {generated_answers[0]}")
    ```
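
As an alternative to cloning in step 2, the repository files can be fetched with `huggingface_hub`'s `snapshot_download`. This is a generic sketch; the local directory name is an arbitrary choice.

```python
from huggingface_hub import snapshot_download

# Download every file in the repository (including qwen_generator/) to a local folder.
local_dir = snapshot_download(
    repo_id="moelanoby/ALM-Qwen-0.5B-testing",
    local_dir="ALM-Qwen-0.5B-testing",
)
print(f"Files downloaded to: {local_dir}")
```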

## Training

The ALM layer weights (`alm_layer_state_dict.pth`) may have been trained separately from the generator. The Qwen model inside `qwen_generator/` is typically a pre-trained model from Hugging Face, possibly fine-tuned.
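
If you want to continue training, the natural split is to update only the ALM layer while keeping the Qwen generator frozen. The sketch below illustrates that pattern under stated assumptions: `AttentionLinkedMemory` is the class name given in the component list above, but its constructor signature, the `"alm_params"` config key, and the output file name are assumptions, not taken from this repository.

```python
import json
import torch
from ALM import AttentionLinkedMemory  # class name from the component list; signature assumed

# The README states alm_qwen_hf_config.json includes the ALM parameters; the exact
# key layout ("alm_params") is an assumption for illustration.
with open("alm_qwen_hf_config.json") as f:
    config = json.load(f)

alm_layer = AttentionLinkedMemory(**config["alm_params"])
alm_layer.load_state_dict(torch.load("alm_layer_state_dict.pth", map_location="cpu"))

# Fine-tune only the ALM layer; the Qwen model under qwen_generator/ stays untouched.
optimizer = torch.optim.AdamW(alm_layer.parameters(), lr=1e-4)
# ... task-specific training loop over (query, memory) batches goes here ...

# Save updated weights in the same format as the shipped checkpoint (new file name assumed).
torch.save(alm_layer.state_dict(), "alm_layer_state_dict_finetuned.pth")
```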

## Notes

*   The Qwen model components can be large. Ensure you have sufficient disk space and network bandwidth.
*   The `load_model` method in `alm_qwen_hf.py` handles the reconstruction of the composite model.
*   If you run into errors, use `alm_qwen.py` directly.