File size: 8,204 Bytes
278ebdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61585a2
278ebdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61585a2
278ebdd
 
 
 
 
 
61585a2
278ebdd
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""
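🔧 LoRA Merger Space
Downloads the FSDP checkpoint (LoRA adapter) from the Hub, merges it into the
base model, and uploads the merged model back to the Hub.
(FSDP 체크포인트를 다운받아 베이스 모델과 병합 후 Hub에 업로드합니다.)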
"""

import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from huggingface_hub import snapshot_download, HfApi, login
import logging

# --- Logging ---------------------------------------------------------------
# Timestamped INFO-level log lines; merge_model() mirrors every message it
# logs into the Gradio log textbox as well.
logging.basicConfig(
    format='%(asctime)s | %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Configuration ---------------------------------------------------------
# Hub repo that holds the trained LoRA adapter (downloaded by merge_model()).
SOURCE_REPO = "ongilLabs/IB-Math-Ontology-7B"
# Base checkpoint the adapter is applied to and merged into.
BASE_MODEL = "Qwen/Qwen2.5-Math-7B-Instruct"
# Hub repo the merged model is uploaded to.
# NOTE(review): identical to SOURCE_REPO, so the merged weights are uploaded
# into the same repo as the adapter files — confirm this is intended.
OUTPUT_REPO = "ongilLabs/IB-Math-Ontology-7B"

def _find_adapter_dir(local_dir):
    """Return the directory under *local_dir* that holds the LoRA adapter.

    ``<local_dir>/last-checkpoint`` is checked before *local_dir* itself; a
    directory qualifies when it contains ``adapter_config.json``.

    Returns:
        The matching directory path, or ``None`` when neither location has an
        ``adapter_config.json``.
    """
    for candidate in (os.path.join(local_dir, "last-checkpoint"), local_dir):
        if os.path.exists(os.path.join(candidate, "adapter_config.json")):
            return candidate
    return None


def _build_model_card(base_model, repo_id):
    """Return the README.md (Hub model card) text for the merged model.

    Args:
        base_model: Hub id of the base model; written into the card metadata
            and, without its namespace, into the description.
        repo_id: Hub id the merged model is published under; used for the
            title and in the usage snippet.
    """
    base_name = base_model.rsplit("/", 1)[-1]
    model_name = repo_id.rsplit("/", 1)[-1]
    # Literal braces in the usage snippet are doubled because this is an f-string.
    return f"""---
license: apache-2.0
base_model: {base_model}
tags:
  - math
  - ib-mathematics
  - qwen2
  - fine-tuned
  - education
  - ontology
  - chain-of-thought
language:
  - en
pipeline_tag: text-generation
---

# {model_name}

Fine-tuned {base_name} for IB Mathematics AA with ontology-based Chain-of-Thought reasoning.

## Features
- 🎯 **IB Math AA Specialized**: Trained on 1,332 ontology-based examples
- 💭 **Chain-of-Thought**: Uses `<think>` tags for step-by-step reasoning
- 📚 **Curriculum-Aligned**: Covers all 5 IB Math AA topics
- ⚠️ **Pitfall Awareness**: Warns about common student mistakes

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{repo_id}", torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("{repo_id}")

prompt = "Find the derivative of f(x) = x³ - 2x² + 5x [6 marks]"
messages = [
    {{"role": "system", "content": "You are an expert IB Mathematics AA tutor. Think step-by-step and explain concepts clearly."}},
    {{"role": "user", "content": prompt}}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Training Details
- **Base Model**: {base_name}
- **Method**: LoRA (r=64, alpha=128)
- **Dataset**: 1,332 IB Math Ontology examples with CoT
- **Hardware**: NVIDIA A100 (80GB)
- **Epochs**: 3
- **Precision**: BF16
"""


def merge_model(progress=gr.Progress()):
    """Merge the LoRA adapter from SOURCE_REPO into BASE_MODEL and publish it.

    Pipeline:
      1. Download the SOURCE_REPO snapshot to ``/tmp/checkpoint``.
      2. Locate the adapter directory.
      3. Load BASE_MODEL (bfloat16, ``device_map="auto"``) and its tokenizer.
      4. Apply the adapter with PEFT and merge it into the base weights.
      5. Save the merged model, tokenizer and a model card to ``/tmp/merged_model``.
      6. Upload that folder to OUTPUT_REPO.

    Args:
        progress: Gradio progress tracker used to report each step to the UI.

    Returns:
        The accumulated log text, one message per line. Errors are not raised.
        They are appended to the log (with a traceback) and the log is
        returned, so the UI textbox always shows what happened.
    """
    import shutil
    import traceback

    logs = []

    def log(msg):
        # Send each message to the Python logger and to the UI log buffer.
        logger.info(msg)
        logs.append(msg)
        return "\n".join(logs)

    # Read the token once; both the download and the upload use it.
    token = os.getenv("HF_TOKEN")

    try:
        if not token:
            log("⚠️ HF_TOKEN is not set; Hub calls will use default credentials (if any).")
        if SOURCE_REPO == OUTPUT_REPO:
            # Without delete_patterns, upload_folder() does not remove existing
            # files, so the adapter files remain next to the merged weights.
            log(f"⚠️ OUTPUT_REPO is the same repo as SOURCE_REPO ({SOURCE_REPO}); "
                "the merged model will be uploaded next to the existing adapter files.")

        # Step 1: Download checkpoint
        progress(0.1, desc="📥 Downloading checkpoint...")
        log("📥 Downloading checkpoint from Hub...")

        local_dir = snapshot_download(
            repo_id=SOURCE_REPO,
            local_dir="/tmp/checkpoint",
            token=token,
        )
        log(f"   Downloaded to: {local_dir}")

        # Step 2: Find adapter
        progress(0.2, desc="🔍 Finding adapter...")
        adapter_path = _find_adapter_dir(local_dir)

        if adapter_path is None:
            # List every downloaded file so the user can see what the repo holds.
            log("❌ adapter_config.json not found!")
            log("📂 Available files:")
            for root, _dirs, files in os.walk(local_dir):
                for name in sorted(files):
                    rel_path = os.path.relpath(os.path.join(root, name), local_dir)
                    log(f"   - {rel_path}")
            return "\n".join(logs) + "\n\n❌ FAILED: No adapter found"

        log(f"✅ Found adapter at: {adapter_path}")

        # Step 3: Load base model
        progress(0.3, desc="📦 Loading base model...")
        log(f"📦 Loading base model: {BASE_MODEL}")
        log("   This may take 3-5 minutes...")

        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        log("   ✅ Base model loaded!")

        # Step 4: Load tokenizer
        progress(0.4, desc="📝 Loading tokenizer...")
        log("📝 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
        log("   ✅ Tokenizer loaded!")

        # Step 5: Load LoRA adapter
        progress(0.5, desc="🔗 Loading LoRA adapter...")
        log(f"🔗 Loading LoRA adapter from: {adapter_path}")

        model = PeftModel.from_pretrained(
            base_model,
            adapter_path,
            torch_dtype=torch.bfloat16,
        )
        log("   ✅ LoRA adapter loaded!")

        # Step 6: Merge the LoRA deltas into the base weights and drop the PEFT wrapper.
        progress(0.6, desc="🔧 Merging LoRA with base model...")
        log("🔧 Merging LoRA weights with base model...")
        model = model.merge_and_unload()
        log("   ✅ Merge complete!")

        # Step 7: Save
        progress(0.7, desc="💾 Saving merged model...")
        output_dir = "/tmp/merged_model"
        log(f"💾 Saving merged model to: {output_dir}")

        # Start from an empty directory so files left over from an earlier run
        # (e.g. shards from a different sharding) are not uploaded again.
        shutil.rmtree(output_dir, ignore_errors=True)
        os.makedirs(output_dir, exist_ok=True)
        model.save_pretrained(output_dir, safe_serialization=True, max_shard_size="5GB")
        tokenizer.save_pretrained(output_dir)

        log("   📂 Saved files:")
        for name in sorted(os.listdir(output_dir)):
            size_mb = os.path.getsize(os.path.join(output_dir, name)) / (1024 * 1024)
            log(f"      - {name}: {size_mb:.1f} MB")

        # Step 8: Create model card (explicit UTF-8: the card contains non-ASCII text).
        progress(0.8, desc="📝 Creating model card...")
        log("📝 Creating model card...")

        with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(_build_model_card(BASE_MODEL, OUTPUT_REPO))
        log("   ✅ Model card created!")

        # Step 9: Upload to Hub
        progress(0.9, desc="🚀 Uploading to Hub...")
        log(f"🚀 Uploading to Hub: {OUTPUT_REPO}")

        api = HfApi(token=token)
        api.upload_folder(
            folder_path=output_dir,
            repo_id=OUTPUT_REPO,
            commit_message="✨ Merged LoRA with base model - Production ready",
        )

        log(f"   ✅ Uploaded to: https://huggingface.co/{OUTPUT_REPO}")

        # Done!
        progress(1.0, desc="🎉 Complete!")
        log("")
        log("=" * 50)
        log("🎉 SUCCESS! Model merged and uploaded!")
        log("=" * 50)
        log(f"📍 Model URL: https://huggingface.co/{OUTPUT_REPO}")

        return "\n".join(logs)

    except Exception as e:
        # UI boundary: report the failure in the log textbox instead of raising.
        log(f"\n❌ ERROR: {e}")
        log(traceback.format_exc())
        return "\n".join(logs)


def create_ui():
    """Build the Gradio Blocks app for the merger Space.

    The page describes the configured source/base/output repos. It has a
    button that runs merge_model() and a textbox that receives the log text
    merge_model() returns.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    # Rendered from the module constants so the page always matches what
    # merge_model() actually downloads, merges and uploads.
    base_name = BASE_MODEL.rsplit("/", 1)[-1]

    with gr.Blocks(title="LoRA Merger") as app:
        gr.Markdown(f"""
        # 🔧 IB-Math-Ontology LoRA Merger
        
        This Space merges the LoRA adapter with the base model.
        
        **Source**: `{SOURCE_REPO}` (LoRA adapter)  
        **Base**: `{BASE_MODEL}`  
        **Output**: `{OUTPUT_REPO}` (merged model)
        
        **Steps:**
        1. Download LoRA checkpoint from Hub
        2. Load base model ({base_name})
        3. Load LoRA adapter
        4. Merge LoRA weights into base model
        5. Save merged model, tokenizer and model card
        6. Upload merged model to Hub
        """)
        
        with gr.Row():
            merge_btn = gr.Button("🚀 Start Merge", variant="primary", scale=2)
        
        output = gr.Textbox(
            label="Logs",
            lines=30,
            max_lines=50
        )
        
        # merge_model() returns the full log text, which fills the textbox.
        merge_btn.click(fn=merge_model, outputs=output)
        
        gr.Markdown("""
        ---
        **Note**: This process takes about 10-15 minutes. Make sure you have enough GPU memory.
        """)
    
    return app


if __name__ == "__main__":
    # Script entry point: build the Gradio UI and start serving it.
    app = create_ui()
    app.launch()