ongilLabs committed
Commit 278ebdd · verified · 1 Parent(s): edbdbf9

Upload app.py with huggingface_hub

Files changed (1):
app.py +251 -0

app.py ADDED
@@ -0,0 +1,251 @@
+ """
+ 🔧 LoRA Merger Space
+ Downloads the FSDP checkpoint, merges it with the base model, and uploads the result to the Hub.
+ """
+
+ import logging
+ import os
+ import traceback
+
+ import torch
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ from huggingface_hub import snapshot_download, HfApi
+
+ # Logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # Configuration
+ SOURCE_REPO = "ongilLabs/IB-Math-Ontology-7B"  # LoRA adapter
+ BASE_MODEL = "Qwen/Qwen2.5-Math-7B-Instruct"
+ OUTPUT_REPO = "ongilLabs/IB-Math-Ontology-7B"  # Merged model output
+
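+ # Note: HF_TOKEN is read from the environment (e.g. a Space secret); it is used
+ # both to download SOURCE_REPO and to push to OUTPUT_REPO, so it needs write access.
+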
+ def merge_model(progress=gr.Progress()):
+     """Main merge function."""
+     logs = []
+
+     def log(msg):
+         logger.info(msg)
+         logs.append(msg)
+         return "\n".join(logs)
+
+     try:
+         # Step 1: Download checkpoint
+         progress(0.1, desc="📥 Downloading checkpoint...")
+         log("📥 Downloading checkpoint from Hub...")
+
+         local_dir = snapshot_download(
+             repo_id=SOURCE_REPO,
+             local_dir="/tmp/checkpoint",
+             token=os.getenv("HF_TOKEN")
+         )
+         log(f" Downloaded to: {local_dir}")
+
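+         # The adapter may sit at the repo root or under "last-checkpoint",
+         # where FSDP training runs commonly leave the latest PEFT state.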
+         # Step 2: Find adapter
+         progress(0.2, desc="🔍 Finding adapter...")
+         adapter_path = None
+
+         # Check locations
+         for path in [f"{local_dir}/last-checkpoint", local_dir]:
+             if os.path.exists(f"{path}/adapter_config.json"):
+                 adapter_path = path
+                 log(f"✅ Found adapter at: {path}")
+                 break
+
+         if not adapter_path:
+             # List files for debugging
+             log("❌ adapter_config.json not found!")
+             log("📂 Available files:")
+             for root, dirs, files in os.walk(local_dir):
+                 for f in files:
+                     rel_path = os.path.relpath(os.path.join(root, f), local_dir)
+                     log(f" - {rel_path}")
+             return "\n".join(logs) + "\n\n❌ FAILED: No adapter found"
+
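+         # device_map="auto" lets accelerate place layers across the available
+         # GPU(s) and CPU; bfloat16 halves memory use relative to float32.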
+         # Step 3: Load base model
+         progress(0.3, desc="📦 Loading base model...")
+         log(f"📦 Loading base model: {BASE_MODEL}")
+         log(" This may take 3-5 minutes...")
+
+         base_model = AutoModelForCausalLM.from_pretrained(
+             BASE_MODEL,
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             trust_remote_code=True,
+         )
+         log(" ✅ Base model loaded!")
+
+         # Step 4: Load tokenizer
+         progress(0.4, desc="📝 Loading tokenizer...")
+         log("📝 Loading tokenizer...")
+         tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
+         log(" ✅ Tokenizer loaded!")
+
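+         # PeftModel.from_pretrained attaches the trained low-rank A/B matrices
+         # alongside the frozen base weights; nothing is merged at this point.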
+         # Step 5: Load LoRA adapter
+         progress(0.5, desc="🔗 Loading LoRA adapter...")
+         log(f"🔗 Loading LoRA adapter from: {adapter_path}")
+
+         model = PeftModel.from_pretrained(
+             base_model,
+             adapter_path,
+             torch_dtype=torch.bfloat16,
+         )
+         log(" ✅ LoRA adapter loaded!")
+
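+         # merge_and_unload() folds each LoRA delta (scaling * B @ A) into the
+         # corresponding base weight and returns a plain transformers model, so
+         # the merged checkpoint no longer needs peft at inference time.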
+         # Step 6: Merge
+         progress(0.6, desc="🔧 Merging LoRA with base model...")
+         log("🔧 Merging LoRA weights with base model...")
+         model = model.merge_and_unload()
+         log(" ✅ Merge complete!")
+
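+         # safe_serialization=True writes .safetensors shards; capping shards at
+         # 5 GB keeps each file comfortably uploadable to the Hub.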
+         # Step 7: Save
+         progress(0.7, desc="💾 Saving merged model...")
+         output_dir = "/tmp/merged_model"
+         log(f"💾 Saving merged model to: {output_dir}")
+
+         os.makedirs(output_dir, exist_ok=True)
+         model.save_pretrained(output_dir, safe_serialization=True, max_shard_size="5GB")
+         tokenizer.save_pretrained(output_dir)
+
+         # List saved files
+         log(" 📂 Saved files:")
+         for f in os.listdir(output_dir):
+             size_mb = os.path.getsize(os.path.join(output_dir, f)) / (1024 * 1024)
+             log(f" - {f}: {size_mb:.1f} MB")
+
+         # Step 8: Create model card
+         progress(0.8, desc="📝 Creating model card...")
+         log("📝 Creating model card...")
+
+         model_card = """---
+ license: apache-2.0
+ base_model: Qwen/Qwen2.5-Math-7B-Instruct
+ tags:
+ - math
+ - ib-mathematics
+ - qwen2
+ - fine-tuned
+ - education
+ - ontology
+ - chain-of-thought
+ language:
+ - en
+ pipeline_tag: text-generation
+ ---
+
+ # IB-Math-Ontology-7B
+
+ Fine-tuned Qwen2.5-Math-7B-Instruct for IB Mathematics AA with ontology-based Chain-of-Thought reasoning.
+
+ ## Features
+ - 🎯 **IB Math AA Specialized**: Trained on 1,332 ontology-based examples
+ - 💭 **Chain-of-Thought**: Uses `<think>` tags for step-by-step reasoning
+ - 📚 **Curriculum-Aligned**: Covers all 5 IB Math AA topics
+ - ⚠️ **Pitfall Awareness**: Warns about common student mistakes
+
+ ## Usage
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("ongilLabs/IB-Math-Ontology-7B", torch_dtype="auto", device_map="auto")
+ tokenizer = AutoTokenizer.from_pretrained("ongilLabs/IB-Math-Ontology-7B")
+
+ prompt = "Find the derivative of f(x) = x³ - 2x² + 5x [6 marks]"
+ messages = [
+     {"role": "system", "content": "You are an expert IB Mathematics AA tutor. Think step-by-step and explain concepts clearly."},
+     {"role": "user", "content": prompt}
+ ]
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)
+ outputs = model.generate(**inputs, max_new_tokens=512)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```
+
+ ## Training Details
+ - **Base Model**: Qwen2.5-Math-7B-Instruct
+ - **Method**: LoRA (r=64, alpha=128)
+ - **Dataset**: 1,332 IB Math Ontology examples with CoT
+ - **Hardware**: NVIDIA A100 (80GB)
+ - **Epochs**: 3
+ - **Precision**: BF16
+ """
+
+         with open(os.path.join(output_dir, "README.md"), "w") as f:
+             f.write(model_card)
+         log(" ✅ Model card created!")
+
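+         # upload_folder pushes the whole directory as a single commit; OUTPUT_REPO
+         # already exists here because it is the same repo as SOURCE_REPO.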
+         # Step 9: Upload to Hub
+         progress(0.9, desc="🚀 Uploading to Hub...")
+         log(f"🚀 Uploading to Hub: {OUTPUT_REPO}")
+
+         api = HfApi(token=os.getenv("HF_TOKEN"))
+         api.upload_folder(
+             folder_path=output_dir,
+             repo_id=OUTPUT_REPO,
+             commit_message="✨ Merged LoRA with base model - Production ready",
+         )
+
+         log(f" ✅ Uploaded to: https://huggingface.co/{OUTPUT_REPO}")
+
+         # Done!
+         progress(1.0, desc="🎉 Complete!")
+         log("")
+         log("=" * 50)
+         log("🎉 SUCCESS! Model merged and uploaded!")
+         log("=" * 50)
+         log(f"📍 Model URL: https://huggingface.co/{OUTPUT_REPO}")
+
+         return "\n".join(logs)
+
+     except Exception as e:
+         log(f"\n❌ ERROR: {str(e)}")
+         log(traceback.format_exc())
+         return "\n".join(logs)
+
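+ # merge_model returns the accumulated log once, at the end; interim feedback
+ # comes from the gr.Progress bar rather than live Textbox updates.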
+
+ def create_ui():
+     """Create the Gradio UI."""
+     with gr.Blocks(title="LoRA Merger", theme=gr.themes.Soft()) as app:
+         gr.Markdown("""
+         # 🔧 IB-Math-Ontology LoRA Merger
+
+         This Space merges the LoRA adapter with the base model.
+
+         **Source**: `ongilLabs/IB-Math-Ontology-7B` (LoRA adapter)
+         **Base**: `Qwen/Qwen2.5-Math-7B-Instruct`
+         **Output**: `ongilLabs/IB-Math-Ontology-7B` (merged model)
+
+         **Steps:**
+         1. Download LoRA checkpoint from Hub
+         2. Load base model (Qwen2.5-Math-7B-Instruct)
+         3. Load LoRA adapter
+         4. Merge LoRA weights into base model
+         5. Upload merged model to Hub
+         """)
+
+         with gr.Row():
+             merge_btn = gr.Button("🚀 Start Merge", variant="primary", scale=2)
+
+         output = gr.Textbox(
+             label="Logs",
+             lines=30,
+             max_lines=50,
+             show_copy_button=True
+         )
+
+         merge_btn.click(fn=merge_model, outputs=output)
+
+         gr.Markdown("""
+         ---
+         **Note**: This operation takes about 10-15 minutes. Make sure there is enough GPU memory.
+         """)
+
+     return app
+
+
+ if __name__ == "__main__":
+     app = create_ui()
+     app.launch()
+