Spaces:
Build error
Build error
Commit ·
b813321
1
Parent(s): 9e4cab2
Changes
Browse files- .gitignore +1 -1
- README.md +43 -2
- app.py +0 -24
- backend/app.py +103 -0
- backend/initiate.py +0 -13
- requirements.txt +7 -2
- scripts/evaluate.py +69 -0
- scripts/train.py +91 -0
- src/__init__.py +16 -0
- src/config.py +67 -0
- src/model.py +92 -0
- src/tuning/__init__.py +15 -0
- src/tuning/data.py +70 -0
- src/tuning/lora.py +53 -0
- src/tuning/trainer.py +100 -0
.gitignore
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
**cache**
|
| 2 |
-
*.
|
|
|
|
| 1 |
**cache**
|
| 2 |
+
*.ipynb
|
README.md
CHANGED
|
@@ -2,8 +2,49 @@
|
|
| 2 |
|
| 3 |
A system for compressing long-form content into clear, structured summaries.
|
| 4 |
|
| 5 |
-
Précis is designed for articles, papers, and video transcripts. The goal is to
|
| 6 |
|
| 7 |
## Model
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
A system for compressing long-form content into clear, structured summaries.
|
| 4 |
|
| 5 |
+
Précis is designed for articles, papers, and video transcripts. The goal is to extract the meaningful content and present the main ideas clearly, rather than merely paraphrase the text.
|
| 6 |
|
| 7 |
## Model
|
| 8 |
|
| 9 |
+
Qwen-2.5-7B-Instruct with 4-bit quantization (BitsAndBytes NF4) for efficiency. Fine-tuned using LoRA for summarization.
|
| 10 |
+
|
| 11 |
+
## Installation
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
pip install -r requirements.txt
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## Usage
|
| 18 |
+
|
| 19 |
+
### Training (with dummy data)
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
# Dry run to validate pipeline
|
| 23 |
+
python scripts/train.py --dry-run
|
| 24 |
+
|
| 25 |
+
# Full training
|
| 26 |
+
python scripts/train.py --epochs 3 --batch-size 4
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### Evaluation
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
python scripts/evaluate.py --checkpoint ./outputs
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## API
|
| 36 |
+
|
| 37 |
+
### Running the API
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
python app.py
|
| 41 |
+
# or
|
| 42 |
+
uvicorn app:app --reload
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
### Endpoints
|
| 46 |
+
|
| 47 |
+
- `GET /` — API documentation page
|
| 48 |
+
- `GET /health` — Health check
|
| 49 |
+
- `GET /status` — Service status and model info
|
| 50 |
+
- `POST /summarize` — Summarize content from URL (currently returns dummy data)
|
app.py
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
| 2 |
-
from fastapi.responses import HTMLResponse
|
| 3 |
-
|
| 4 |
-
app = FastAPI(title="Précis — MVP")
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@app.get("/", response_class=HTMLResponse)
|
| 8 |
-
async def root():
|
| 9 |
-
return """
|
| 10 |
-
<html>
|
| 11 |
-
<head>
|
| 12 |
-
<title>Précis — MVP</title>
|
| 13 |
-
</head>
|
| 14 |
-
<body>
|
| 15 |
-
<h1>Précis — MVP</h1>
|
| 16 |
-
<p>Welcome to Précis</p>
|
| 17 |
-
</body>
|
| 18 |
-
</html>
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
if __name__ == "__main__":
|
| 23 |
-
import uvicorn
|
| 24 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/app.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI backend for Précis."""
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI, HTTPException
|
| 4 |
+
from fastapi.responses import HTMLResponse
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
app = FastAPI(
|
| 9 |
+
title="Précis API",
|
| 10 |
+
description="Content summarization API",
|
| 11 |
+
version="0.1.0"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SummarizeRequest(BaseModel):
|
| 16 |
+
"""Request model for summarization."""
|
| 17 |
+
url: str
|
| 18 |
+
max_length: Optional[int] = 512
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SummarizeResponse(BaseModel):
|
| 22 |
+
"""Response model for summarization."""
|
| 23 |
+
url: str
|
| 24 |
+
summary: str
|
| 25 |
+
success: bool
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@app.get("/", response_class=HTMLResponse)
|
| 29 |
+
async def root():
|
| 30 |
+
"""Root endpoint with basic info."""
|
| 31 |
+
return """
|
| 32 |
+
<!DOCTYPE html>
|
| 33 |
+
<html>
|
| 34 |
+
<head>
|
| 35 |
+
<title>Précis API</title>
|
| 36 |
+
<style>
|
| 37 |
+
body { font-family: system-ui; max-width: 800px; margin: 50px auto; padding: 20px; }
|
| 38 |
+
h1 { color: #333; }
|
| 39 |
+
code { background: #f4f4f4; padding: 2px 6px; border-radius: 3px; }
|
| 40 |
+
</style>
|
| 41 |
+
</head>
|
| 42 |
+
<body>
|
| 43 |
+
<h1>Précis API</h1>
|
| 44 |
+
<p>Content summarization service</p>
|
| 45 |
+
<h2>Endpoints</h2>
|
| 46 |
+
<ul>
|
| 47 |
+
<li><code>POST /summarize</code> - Summarize content from URL</li>
|
| 48 |
+
<li><code>GET /health</code> - Health check</li>
|
| 49 |
+
<li><code>GET /status</code> - Service status</li>
|
| 50 |
+
<li><code>GET /docs</code> - API documentation</li>
|
| 51 |
+
</ul>
|
| 52 |
+
</body>
|
| 53 |
+
</html>
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@app.get("/health")
|
| 58 |
+
async def health():
|
| 59 |
+
"""Health check endpoint."""
|
| 60 |
+
return {"status": "healthy", "service": "precis"}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@app.get("/status")
|
| 64 |
+
async def status():
|
| 65 |
+
"""Service status endpoint."""
|
| 66 |
+
return {
|
| 67 |
+
"service": "Précis API",
|
| 68 |
+
"version": "0.1.0",
|
| 69 |
+
"model": "Qwen/Qwen2.5-7B-Instruct",
|
| 70 |
+
"model_loaded": False, # TODO: Track actual model state
|
| 71 |
+
"endpoints": ["/", "/health", "/status", "/summarize"]
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@app.post("/summarize", response_model=SummarizeResponse)
async def summarize(request: SummarizeRequest):
    """
    Summarize content from a URL.

    Currently returns dummy data. Will be implemented with actual model.

    Args:
        request: SummarizeRequest with the source ``url`` and an optional
            ``max_length`` token budget (default 512).

    Returns:
        SummarizeResponse echoing the URL plus a placeholder summary.
    """
    # TODO: Implement actual summarization
    # 1. Fetch content from URL
    # 2. Parse text (YouTube transcript or article)
    # 3. Run through model
    # 4. Return summary

    # BUG FIX: the last fragment was a plain string, so the literal text
    # "{request.max_length}" was returned instead of the value — it needs
    # the f-prefix (implicit concatenation does not share the prefix).
    dummy_summary = (
        f"This is a placeholder summary for content at {request.url}. "
        "The actual summarization model will be integrated in the next phase. "
        f"This summary respects the max_length parameter of {request.max_length} tokens."
    )

    return SummarizeResponse(
        url=request.url,
        summary=dummy_summary,
        success=True
    )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
import uvicorn
|
| 103 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
backend/initiate.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 2 |
-
|
| 3 |
-
MODEL = "Qwen/Qwen2.5-7B-Instruct.gguf.q5_0"
|
| 4 |
-
|
| 5 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
|
| 6 |
-
|
| 7 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 8 |
-
MODEL,
|
| 9 |
-
device_map="auto",
|
| 10 |
-
load_in_4bit=True,
|
| 11 |
-
torch_dtype="auto",
|
| 12 |
-
trust_remote_code=True
|
| 13 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,7 +1,12 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
transformers
|
| 3 |
accelerate
|
| 4 |
bitsandbytes
|
| 5 |
-
|
|
|
|
| 6 |
sentencepiece
|
|
|
|
|
|
|
| 7 |
fastapi
|
|
|
|
|
|
| 1 |
+
# Core ML
|
| 2 |
+
torch
|
| 3 |
transformers
|
| 4 |
accelerate
|
| 5 |
bitsandbytes
|
| 6 |
+
peft
|
| 7 |
+
datasets
|
| 8 |
sentencepiece
|
| 9 |
+
|
| 10 |
+
# API
|
| 11 |
fastapi
|
| 12 |
+
uvicorn
|
scripts/evaluate.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""CLI evaluation script for Précis."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 10 |
+
|
| 11 |
+
from src.config import ModelConfig, DataConfig
|
| 12 |
+
from src.model import load_tokenizer
|
| 13 |
+
from src.tuning.data import create_dummy_data
|
| 14 |
+
|
| 15 |
+
from transformers import AutoModelForCausalLM
|
| 16 |
+
from peft import PeftModel
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def parse_args():
|
| 23 |
+
parser = argparse.ArgumentParser(description="Evaluate Précis model")
|
| 24 |
+
parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
|
| 25 |
+
parser.add_argument("--num-samples", type=int, default=5, help="Number of samples to evaluate")
|
| 26 |
+
parser.add_argument("--max-new-tokens", type=int, default=256, help="Max tokens to generate")
|
| 27 |
+
return parser.parse_args()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def main():
    """Generate summaries for dummy samples from a fine-tuned checkpoint.

    Loads the base model, attaches the LoRA adapter saved at
    ``--checkpoint``, and logs generated vs. reference summaries.
    """
    args = parse_args()
    config = ModelConfig()
    data_config = DataConfig()

    logger.info(f"Loading checkpoint from {args.checkpoint}")
    tokenizer = load_tokenizer(config)

    # FIX: the training pipeline saves a PEFT/LoRA *adapter* (see
    # PrecisTrainer.save / trainer.save_model), not a full model, so the
    # checkpoint directory cannot reliably be loaded with
    # AutoModelForCausalLM alone. Load the base model first, then attach
    # the adapter (this also makes the previously-unused PeftModel import
    # meaningful). NOTE(review): assumes the checkpoint holds adapter
    # weights only — confirm against the training run's output.
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_id,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, args.checkpoint)
    model.eval()

    # Generate on dummy samples
    samples = create_dummy_data(args.num_samples)

    for i, sample in enumerate(samples):
        prompt = data_config.format_prompt(sample["text"])
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )

        # Strip the prompt prefix; only the continuation is the summary.
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        summary = generated[len(prompt):]

        logger.info(f"\n=== Sample {i+1} ===")
        logger.info(f"Input: {sample['text'][:100]}...")
        logger.info(f"Generated: {summary}")
        logger.info(f"Reference: {sample['summary']}")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
main()
|
scripts/train.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""CLI training script for Précis."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# Add project root to path
|
| 10 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 11 |
+
|
| 12 |
+
from src.config import ModelConfig, TrainingConfig, DataConfig
|
| 13 |
+
from src.model import load_model, load_tokenizer, prepare_for_training
|
| 14 |
+
from src.tuning.lora import apply_lora
|
| 15 |
+
from src.tuning.data import create_dummy_data, prepare_dataset
|
| 16 |
+
from src.tuning.trainer import PrecisTrainer
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def parse_args():
|
| 23 |
+
parser = argparse.ArgumentParser(description="Train Précis summarization model")
|
| 24 |
+
parser.add_argument("--model-id", type=str, default=None, help="HuggingFace model ID")
|
| 25 |
+
parser.add_argument("--output-dir", type=str, default="./outputs", help="Output directory")
|
| 26 |
+
parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
|
| 27 |
+
parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
|
| 28 |
+
parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")
|
| 29 |
+
parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
|
| 30 |
+
parser.add_argument("--dry-run", action="store_true", help="Validate pipeline without training")
|
| 31 |
+
parser.add_argument("--dummy-samples", type=int, default=100, help="Number of dummy samples")
|
| 32 |
+
return parser.parse_args()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def main():
|
| 36 |
+
args = parse_args()
|
| 37 |
+
|
| 38 |
+
# Build configs
|
| 39 |
+
model_config = ModelConfig()
|
| 40 |
+
if args.model_id:
|
| 41 |
+
model_config.model_id = args.model_id
|
| 42 |
+
|
| 43 |
+
training_config = TrainingConfig(
|
| 44 |
+
output_dir=args.output_dir,
|
| 45 |
+
num_epochs=args.epochs,
|
| 46 |
+
batch_size=args.batch_size,
|
| 47 |
+
learning_rate=args.learning_rate,
|
| 48 |
+
lora_r=args.lora_r,
|
| 49 |
+
)
|
| 50 |
+
data_config = DataConfig()
|
| 51 |
+
|
| 52 |
+
if args.dry_run:
|
| 53 |
+
logger.info("=== DRY RUN MODE ===")
|
| 54 |
+
logger.info(f"Model: {model_config.model_id}")
|
| 55 |
+
logger.info(f"Output: {training_config.output_dir}")
|
| 56 |
+
logger.info(f"Epochs: {training_config.num_epochs}, Batch: {training_config.batch_size}")
|
| 57 |
+
logger.info(f"LoRA r: {training_config.lora_r}, alpha: {training_config.lora_alpha}")
|
| 58 |
+
|
| 59 |
+
# Test data pipeline only
|
| 60 |
+
dummy_data = create_dummy_data(5)
|
| 61 |
+
logger.info(f"Dummy data sample: {dummy_data[0]}")
|
| 62 |
+
logger.info("Dry run complete. Pipeline validated.")
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
# Load model and tokenizer
|
| 66 |
+
logger.info("Loading model and tokenizer...")
|
| 67 |
+
tokenizer = load_tokenizer(model_config)
|
| 68 |
+
model = load_model(model_config)
|
| 69 |
+
model = prepare_for_training(model)
|
| 70 |
+
model = apply_lora(model, training_config)
|
| 71 |
+
|
| 72 |
+
# Prepare data
|
| 73 |
+
logger.info("Preparing training data...")
|
| 74 |
+
train_data = create_dummy_data(args.dummy_samples)
|
| 75 |
+
train_dataset = prepare_dataset(train_data, tokenizer, data_config)
|
| 76 |
+
|
| 77 |
+
# Train
|
| 78 |
+
trainer = PrecisTrainer(
|
| 79 |
+
model=model,
|
| 80 |
+
tokenizer=tokenizer,
|
| 81 |
+
train_dataset=train_dataset,
|
| 82 |
+
config=training_config,
|
| 83 |
+
)
|
| 84 |
+
trainer.train()
|
| 85 |
+
trainer.save()
|
| 86 |
+
|
| 87 |
+
logger.info("Training complete!")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
|
| 91 |
+
main()
|
src/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Précis — Model loading, configuration, and fine-tuning utilities.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from src.config import ModelConfig, TrainingConfig, DataConfig
|
| 6 |
+
from src.model import load_model, load_tokenizer, prepare_for_training
|
| 7 |
+
|
| 8 |
+
__version__ = "0.1.0"
|
| 9 |
+
__all__ = [
|
| 10 |
+
"ModelConfig",
|
| 11 |
+
"TrainingConfig",
|
| 12 |
+
"DataConfig",
|
| 13 |
+
"load_model",
|
| 14 |
+
"load_tokenizer",
|
| 15 |
+
"prepare_for_training",
|
| 16 |
+
]
|
src/config.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration management for Précis."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional, List
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class ModelConfig:
    """Settings that control how the base model is loaded and quantized.

    Defaults select 4-bit NF4 quantization with double quantization,
    suitable for fitting a 7B model on a single consumer GPU.
    """
    # Hugging Face model identifier to load.
    model_id: str = "Qwen/Qwen2.5-7B-Instruct"
    # Quantization mode flags — at most one may be enabled.
    load_in_4bit: bool = True
    load_in_8bit: bool = False
    # BitsAndBytes 4-bit knobs (compute dtype name, quant scheme, nesting).
    bnb_4bit_compute_dtype: str = "float16"
    bnb_4bit_quant_type: str = "nf4"
    bnb_4bit_use_double_quant: bool = True
    # Device placement strategy passed through to transformers.
    device_map: str = "auto"
    trust_remote_code: bool = True
    # Optional override for the HF download cache location.
    cache_dir: Optional[str] = None

    def __post_init__(self):
        # The two quantization modes are mutually exclusive; fail fast.
        both_requested = self.load_in_4bit and self.load_in_8bit
        if both_requested:
            raise ValueError("Cannot enable both 4-bit and 8-bit quantization")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class TrainingConfig:
    """Hyperparameters for LoRA fine-tuning runs.

    Groups LoRA adapter shape, optimizer schedule, sequence limits, and
    checkpoint/logging cadence in one place.
    """
    # --- LoRA adapter shape ---
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # Attention projections targeted by the adapters.
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
    # --- Optimization schedule ---
    learning_rate: float = 2e-4
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    num_epochs: int = 3
    warmup_ratio: float = 0.03
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    # Longest tokenized sequence fed to the model during training.
    max_seq_length: int = 2048
    # Paged AdamW keeps optimizer state manageable alongside quantization.
    optim: str = "paged_adamw_32bit"
    # --- Checkpointing / logging cadence (in optimizer steps) ---
    save_steps: int = 100
    logging_steps: int = 10
    eval_steps: int = 100
    output_dir: str = "./outputs"
    seed: int = 42
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
class DataConfig:
    """Dataset locations, column names, length limits, and prompt shape."""
    # Optional paths to training / evaluation files (None -> dummy data).
    train_file: Optional[str] = None
    eval_file: Optional[str] = None
    # Column names holding the source document and its reference summary.
    input_column: str = "text"
    target_column: str = "summary"
    # Token budgets for the document and the generated summary.
    max_input_length: int = 1536
    max_target_length: int = 512
    # Fraction of data assigned to the training split.
    train_split: float = 0.9
    # Instruction template; ``{input}`` is replaced with the document text.
    prompt_template: str = "Summarize the following document:\n\n### Document:\n{input}\n\n### Summary:\n"

    def format_prompt(self, text: str) -> str:
        """Render *text* into the summarization prompt."""
        rendered = self.prompt_template.format(input=text)
        return rendered
|
src/model.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading utilities for Précis."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoModelForCausalLM,
|
| 9 |
+
AutoTokenizer,
|
| 10 |
+
BitsAndBytesConfig,
|
| 11 |
+
PreTrainedModel,
|
| 12 |
+
PreTrainedTokenizer,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from src.config import ModelConfig
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_quantization_config(config: ModelConfig) -> Optional[BitsAndBytesConfig]:
|
| 21 |
+
"""Create BitsAndBytes quantization configuration."""
|
| 22 |
+
if config.load_in_4bit:
|
| 23 |
+
compute_dtype = getattr(torch, config.bnb_4bit_compute_dtype)
|
| 24 |
+
return BitsAndBytesConfig(
|
| 25 |
+
load_in_4bit=True,
|
| 26 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
| 27 |
+
bnb_4bit_quant_type=config.bnb_4bit_quant_type,
|
| 28 |
+
bnb_4bit_use_double_quant=config.bnb_4bit_use_double_quant,
|
| 29 |
+
)
|
| 30 |
+
elif config.load_in_8bit:
|
| 31 |
+
return BitsAndBytesConfig(load_in_8bit=True)
|
| 32 |
+
return None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def load_tokenizer(config: Optional[ModelConfig] = None) -> PreTrainedTokenizer:
|
| 36 |
+
"""Load and configure the tokenizer."""
|
| 37 |
+
if config is None:
|
| 38 |
+
config = ModelConfig()
|
| 39 |
+
|
| 40 |
+
logger.info(f"Loading tokenizer: {config.model_id}")
|
| 41 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 42 |
+
config.model_id,
|
| 43 |
+
trust_remote_code=config.trust_remote_code,
|
| 44 |
+
cache_dir=config.cache_dir,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
if tokenizer.pad_token is None:
|
| 48 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 49 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 50 |
+
tokenizer.padding_side = "right"
|
| 51 |
+
|
| 52 |
+
return tokenizer
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def load_model(config: Optional[ModelConfig] = None) -> PreTrainedModel:
|
| 56 |
+
"""Load the base model with optional quantization."""
|
| 57 |
+
if config is None:
|
| 58 |
+
config = ModelConfig()
|
| 59 |
+
|
| 60 |
+
logger.info(f"Loading model: {config.model_id}")
|
| 61 |
+
quantization_config = get_quantization_config(config)
|
| 62 |
+
|
| 63 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 64 |
+
config.model_id,
|
| 65 |
+
quantization_config=quantization_config,
|
| 66 |
+
device_map=config.device_map,
|
| 67 |
+
trust_remote_code=config.trust_remote_code,
|
| 68 |
+
cache_dir=config.cache_dir,
|
| 69 |
+
torch_dtype=torch.float16 if quantization_config else "auto",
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
logger.info(f"Model loaded. Parameters: {model.num_parameters():,}")
|
| 73 |
+
return model
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def prepare_for_training(model: PreTrainedModel, gradient_checkpointing: bool = True) -> PreTrainedModel:
|
| 77 |
+
"""Prepare model for training with gradient checkpointing and k-bit setup."""
|
| 78 |
+
if gradient_checkpointing:
|
| 79 |
+
model.gradient_checkpointing_enable()
|
| 80 |
+
|
| 81 |
+
if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
|
| 82 |
+
from peft import prepare_model_for_kbit_training
|
| 83 |
+
model = prepare_model_for_kbit_training(model)
|
| 84 |
+
|
| 85 |
+
return model
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def load_model_and_tokenizer(config: Optional[ModelConfig] = None) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
| 89 |
+
"""Load both model and tokenizer."""
|
| 90 |
+
if config is None:
|
| 91 |
+
config = ModelConfig()
|
| 92 |
+
return load_model(config), load_tokenizer(config)
|
src/tuning/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tuning subpackage for Précis."""
|
| 2 |
+
|
| 3 |
+
from src.tuning.lora import get_lora_config, apply_lora, merge_and_save
|
| 4 |
+
from src.tuning.data import SummarizationDataset, prepare_dataset, create_dummy_data
|
| 5 |
+
from src.tuning.trainer import PrecisTrainer
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"get_lora_config",
|
| 9 |
+
"apply_lora",
|
| 10 |
+
"merge_and_save",
|
| 11 |
+
"SummarizationDataset",
|
| 12 |
+
"prepare_dataset",
|
| 13 |
+
"create_dummy_data",
|
| 14 |
+
"PrecisTrainer",
|
| 15 |
+
]
|
src/tuning/data.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data preparation utilities for training."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Dict, List, Optional, Any
|
| 5 |
+
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
from transformers import PreTrainedTokenizer
|
| 8 |
+
|
| 9 |
+
from src.config import DataConfig
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SummarizationDataset(Dataset):
|
| 15 |
+
"""PyTorch Dataset for summarization training."""
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
data: List[Dict[str, str]],
|
| 20 |
+
tokenizer: PreTrainedTokenizer,
|
| 21 |
+
config: Optional[DataConfig] = None,
|
| 22 |
+
):
|
| 23 |
+
self.data = data
|
| 24 |
+
self.tokenizer = tokenizer
|
| 25 |
+
self.config = config or DataConfig()
|
| 26 |
+
|
| 27 |
+
def __len__(self) -> int:
|
| 28 |
+
return len(self.data)
|
| 29 |
+
|
| 30 |
+
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
| 31 |
+
item = self.data[idx]
|
| 32 |
+
prompt = self.config.format_prompt(item[self.config.input_column])
|
| 33 |
+
full_text = prompt + item[self.config.target_column] + self.tokenizer.eos_token
|
| 34 |
+
|
| 35 |
+
encoding = self.tokenizer(
|
| 36 |
+
full_text,
|
| 37 |
+
truncation=True,
|
| 38 |
+
max_length=self.config.max_input_length + self.config.max_target_length,
|
| 39 |
+
padding="max_length",
|
| 40 |
+
return_tensors="pt",
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
return {
|
| 44 |
+
"input_ids": encoding["input_ids"].squeeze(),
|
| 45 |
+
"attention_mask": encoding["attention_mask"].squeeze(),
|
| 46 |
+
"labels": encoding["input_ids"].squeeze(),
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def create_dummy_data(num_samples: int = 10) -> List[Dict[str, str]]:
|
| 51 |
+
"""Generate dummy data for testing the training pipeline."""
|
| 52 |
+
samples = []
|
| 53 |
+
for i in range(num_samples):
|
| 54 |
+
samples.append({
|
| 55 |
+
"text": f"This is sample document {i}. It contains information about topic {i % 3}. "
|
| 56 |
+
f"The document discusses various aspects and provides detailed analysis. "
|
| 57 |
+
f"Key points include methodology, results, and conclusions for study {i}.",
|
| 58 |
+
"summary": f"Document {i} analyzes topic {i % 3}, covering methodology, results, and conclusions.",
|
| 59 |
+
})
|
| 60 |
+
logger.info(f"Created {num_samples} dummy samples")
|
| 61 |
+
return samples
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def prepare_dataset(
|
| 65 |
+
data: List[Dict[str, str]],
|
| 66 |
+
tokenizer: PreTrainedTokenizer,
|
| 67 |
+
config: Optional[DataConfig] = None,
|
| 68 |
+
) -> SummarizationDataset:
|
| 69 |
+
"""Prepare dataset for training."""
|
| 70 |
+
return SummarizationDataset(data, tokenizer, config)
|
src/tuning/lora.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LoRA/PEFT configuration and utilities."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from peft import LoraConfig, get_peft_model, PeftModel, TaskType
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
+
|
| 10 |
+
from src.config import TrainingConfig
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_lora_config(config: Optional[TrainingConfig] = None) -> LoraConfig:
|
| 16 |
+
"""Create LoRA configuration for summarization task."""
|
| 17 |
+
if config is None:
|
| 18 |
+
config = TrainingConfig()
|
| 19 |
+
|
| 20 |
+
return LoraConfig(
|
| 21 |
+
r=config.lora_r,
|
| 22 |
+
lora_alpha=config.lora_alpha,
|
| 23 |
+
lora_dropout=config.lora_dropout,
|
| 24 |
+
target_modules=config.lora_target_modules,
|
| 25 |
+
bias="none",
|
| 26 |
+
task_type=TaskType.CAUSAL_LM,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def apply_lora(model: PreTrainedModel, config: Optional[TrainingConfig] = None) -> PeftModel:
|
| 31 |
+
"""Apply LoRA adapters to model."""
|
| 32 |
+
lora_config = get_lora_config(config)
|
| 33 |
+
logger.info(f"Applying LoRA with r={lora_config.r}, alpha={lora_config.lora_alpha}")
|
| 34 |
+
|
| 35 |
+
peft_model = get_peft_model(model, lora_config)
|
| 36 |
+
peft_model.print_trainable_parameters()
|
| 37 |
+
|
| 38 |
+
return peft_model
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def merge_and_save(model: PeftModel, output_path: str, tokenizer=None) -> None:
|
| 42 |
+
"""Merge LoRA weights into base model and save."""
|
| 43 |
+
output_dir = Path(output_path)
|
| 44 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 45 |
+
|
| 46 |
+
logger.info("Merging LoRA weights...")
|
| 47 |
+
merged_model = model.merge_and_unload()
|
| 48 |
+
|
| 49 |
+
logger.info(f"Saving merged model to {output_dir}")
|
| 50 |
+
merged_model.save_pretrained(output_dir)
|
| 51 |
+
|
| 52 |
+
if tokenizer:
|
| 53 |
+
tokenizer.save_pretrained(output_dir)
|
src/tuning/trainer.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training orchestration for Précis."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from transformers import (
|
| 8 |
+
Trainer,
|
| 9 |
+
TrainingArguments,
|
| 10 |
+
PreTrainedModel,
|
| 11 |
+
PreTrainedTokenizer,
|
| 12 |
+
DataCollatorForLanguageModeling,
|
| 13 |
+
)
|
| 14 |
+
from torch.utils.data import Dataset
|
| 15 |
+
|
| 16 |
+
from src.config import TrainingConfig
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PrecisTrainer:
|
| 22 |
+
"""Wrapper around HuggingFace Trainer for summarization fine-tuning."""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
model: PreTrainedModel,
|
| 27 |
+
tokenizer: PreTrainedTokenizer,
|
| 28 |
+
train_dataset: Dataset,
|
| 29 |
+
eval_dataset: Optional[Dataset] = None,
|
| 30 |
+
config: Optional[TrainingConfig] = None,
|
| 31 |
+
):
|
| 32 |
+
self.model = model
|
| 33 |
+
self.tokenizer = tokenizer
|
| 34 |
+
self.train_dataset = train_dataset
|
| 35 |
+
self.eval_dataset = eval_dataset
|
| 36 |
+
self.config = config or TrainingConfig()
|
| 37 |
+
|
| 38 |
+
self.training_args = self._create_training_args()
|
| 39 |
+
self.trainer = self._create_trainer()
|
| 40 |
+
|
| 41 |
+
def _create_training_args(self) -> TrainingArguments:
|
| 42 |
+
"""Create HuggingFace TrainingArguments from config."""
|
| 43 |
+
return TrainingArguments(
|
| 44 |
+
output_dir=self.config.output_dir,
|
| 45 |
+
num_train_epochs=self.config.num_epochs,
|
| 46 |
+
per_device_train_batch_size=self.config.batch_size,
|
| 47 |
+
gradient_accumulation_steps=self.config.gradient_accumulation_steps,
|
| 48 |
+
learning_rate=self.config.learning_rate,
|
| 49 |
+
warmup_ratio=self.config.warmup_ratio,
|
| 50 |
+
weight_decay=self.config.weight_decay,
|
| 51 |
+
max_grad_norm=self.config.max_grad_norm,
|
| 52 |
+
optim=self.config.optim,
|
| 53 |
+
logging_steps=self.config.logging_steps,
|
| 54 |
+
save_steps=self.config.save_steps,
|
| 55 |
+
eval_steps=self.config.eval_steps if self.eval_dataset else None,
|
| 56 |
+
evaluation_strategy="steps" if self.eval_dataset else "no",
|
| 57 |
+
save_total_limit=3,
|
| 58 |
+
load_best_model_at_end=bool(self.eval_dataset),
|
| 59 |
+
seed=self.config.seed,
|
| 60 |
+
fp16=True,
|
| 61 |
+
report_to="none",
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def _create_trainer(self) -> Trainer:
|
| 65 |
+
"""Create HuggingFace Trainer instance."""
|
| 66 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 67 |
+
tokenizer=self.tokenizer,
|
| 68 |
+
mlm=False,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
return Trainer(
|
| 72 |
+
model=self.model,
|
| 73 |
+
args=self.training_args,
|
| 74 |
+
train_dataset=self.train_dataset,
|
| 75 |
+
eval_dataset=self.eval_dataset,
|
| 76 |
+
data_collator=data_collator,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def train(self) -> None:
|
| 80 |
+
"""Execute training loop."""
|
| 81 |
+
logger.info("Starting training...")
|
| 82 |
+
self.trainer.train()
|
| 83 |
+
logger.info("Training complete.")
|
| 84 |
+
|
| 85 |
+
def evaluate(self) -> dict:
|
| 86 |
+
"""Run evaluation and return metrics."""
|
| 87 |
+
if self.eval_dataset is None:
|
| 88 |
+
logger.warning("No eval dataset provided")
|
| 89 |
+
return {}
|
| 90 |
+
|
| 91 |
+
logger.info("Running evaluation...")
|
| 92 |
+
return self.trainer.evaluate()
|
| 93 |
+
|
| 94 |
+
def save(self, output_path: Optional[str] = None) -> None:
|
| 95 |
+
"""Save model checkpoint."""
|
| 96 |
+
path = output_path or self.config.output_dir
|
| 97 |
+
Path(path).mkdir(parents=True, exist_ok=True)
|
| 98 |
+
self.trainer.save_model(path)
|
| 99 |
+
self.tokenizer.save_pretrained(path)
|
| 100 |
+
logger.info(f"Model saved to {path}")
|