compendious commited on
Commit
b813321
·
1 Parent(s): 9e4cab2
.gitignore CHANGED
@@ -1,2 +1,2 @@
1
  **cache**
2
- *.ipyn
 
1
  **cache**
2
+ *.ipynb
README.md CHANGED
@@ -2,8 +2,49 @@
2
 
3
  A system for compressing long-form content into clear, structured summaries.
4
 
5
- Précis is designed for articles, papers, and video transcripts. The goal is to be able to extract the meaningful content rather than paraphrase the main ideas.
6
 
7
  ## Model
8
 
9
- The model used is Qwen-2.5-7B-Instruct with 5-bit quantization for efficiency. It's functional for specifically fine-tuning to fit a schema.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  A system for compressing long-form content into clear, structured summaries.
4
 
5
+ Précis is designed for articles, papers, and video transcripts. The goal is to extract meaningful content rather than paraphrase main ideas.
6
 
7
  ## Model
8
 
9
+ Qwen-2.5-7B-Instruct with 4-bit quantization (BitsAndBytes NF4) for efficiency. Fine-tuned using LoRA for summarization.
10
+
11
+ ## Installation
12
+
13
+ ```bash
14
+ pip install -r requirements.txt
15
+ ```
16
+
17
+ ## Usage
18
+
19
+ ### Training (with dummy data)
20
+
21
+ ```bash
22
+ # Dry run to validate pipeline
23
+ python scripts/train.py --dry-run
24
+
25
+ # Full training
26
+ python scripts/train.py --epochs 3 --batch-size 4
27
+ ```
28
+
29
+ ### Evaluation
30
+
31
+ ```bash
32
+ python scripts/evaluate.py --checkpoint ./outputs
33
+ ```
34
+
35
+ ## API
36
+
37
+ ### Running the API
38
+
39
+ ```bash
40
+ python app.py
41
+ # or
42
+ uvicorn app:app --reload
43
+ ```
44
+
45
+ ### Endpoints
46
+
47
+ - `GET /` — API documentation page
48
+ - `GET /health` — Health check
49
+ - `GET /status` — Service status and model info
50
+ - `POST /summarize` — Summarize content from URL (currently returns dummy data)
app.py DELETED
@@ -1,24 +0,0 @@
1
- from fastapi import FastAPI
2
- from fastapi.responses import HTMLResponse
3
-
4
- app = FastAPI(title="Précis — MVP")
5
-
6
-
7
- @app.get("/", response_class=HTMLResponse)
8
- async def root():
9
- return """
10
- <html>
11
- <head>
12
- <title>Précis — MVP</title>
13
- </head>
14
- <body>
15
- <h1>Précis — MVP</h1>
16
- <p>Welcome to Précis</p>
17
- </body>
18
- </html>
19
- """
20
-
21
-
22
- if __name__ == "__main__":
23
- import uvicorn
24
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI backend for Précis."""
2
+
3
+ from fastapi import FastAPI, HTTPException
4
+ from fastapi.responses import HTMLResponse
5
+ from pydantic import BaseModel
6
+ from typing import Optional
7
+
8
+ app = FastAPI(
9
+ title="Précis API",
10
+ description="Content summarization API",
11
+ version="0.1.0"
12
+ )
13
+
14
+
15
class SummarizeRequest(BaseModel):
    """Request body for POST /summarize.

    Fields are validated by pydantic before the endpoint handler runs.
    """
    # Source to summarize — presumably an article or video URL; no URL-format
    # validation is applied here (plain str, not pydantic's HttpUrl) — TODO confirm.
    url: str
    # Requested summary length cap in tokens; currently only echoed back by the
    # dummy /summarize endpoint, not enforced.
    max_length: Optional[int] = 512
19
+
20
+
21
class SummarizeResponse(BaseModel):
    """Response body for POST /summarize."""
    # Echo of the requested URL so clients can correlate responses.
    url: str
    # Generated summary text (currently a placeholder — see /summarize).
    summary: str
    # True when summarization succeeded; the dummy endpoint always sets True.
    success: bool
26
+
27
+
28
+ @app.get("/", response_class=HTMLResponse)
29
+ async def root():
30
+ """Root endpoint with basic info."""
31
+ return """
32
+ <!DOCTYPE html>
33
+ <html>
34
+ <head>
35
+ <title>Précis API</title>
36
+ <style>
37
+ body { font-family: system-ui; max-width: 800px; margin: 50px auto; padding: 20px; }
38
+ h1 { color: #333; }
39
+ code { background: #f4f4f4; padding: 2px 6px; border-radius: 3px; }
40
+ </style>
41
+ </head>
42
+ <body>
43
+ <h1>Précis API</h1>
44
+ <p>Content summarization service</p>
45
+ <h2>Endpoints</h2>
46
+ <ul>
47
+ <li><code>POST /summarize</code> - Summarize content from URL</li>
48
+ <li><code>GET /health</code> - Health check</li>
49
+ <li><code>GET /status</code> - Service status</li>
50
+ <li><code>GET /docs</code> - API documentation</li>
51
+ </ul>
52
+ </body>
53
+ </html>
54
+ """
55
+
56
+
57
+ @app.get("/health")
58
+ async def health():
59
+ """Health check endpoint."""
60
+ return {"status": "healthy", "service": "precis"}
61
+
62
+
63
+ @app.get("/status")
64
+ async def status():
65
+ """Service status endpoint."""
66
+ return {
67
+ "service": "Précis API",
68
+ "version": "0.1.0",
69
+ "model": "Qwen/Qwen2.5-7B-Instruct",
70
+ "model_loaded": False, # TODO: Track actual model state
71
+ "endpoints": ["/", "/health", "/status", "/summarize"]
72
+ }
73
+
74
+
75
+ @app.post("/summarize", response_model=SummarizeResponse)
76
+ async def summarize(request: SummarizeRequest):
77
+ """
78
+ Summarize content from a URL.
79
+
80
+ Currently returns dummy data. Will be implemented with actual model.
81
+ """
82
+ # TODO: Implement actual summarization
83
+ # 1. Fetch content from URL
84
+ # 2. Parse text (YouTube transcript or article)
85
+ # 3. Run through model
86
+ # 4. Return summary
87
+
88
+ dummy_summary = (
89
+ f"This is a placeholder summary for content at {request.url}. "
90
+ "The actual summarization model will be integrated in the next phase. "
91
+ "This summary respects the max_length parameter of {request.max_length} tokens."
92
+ )
93
+
94
+ return SummarizeResponse(
95
+ url=request.url,
96
+ summary=dummy_summary,
97
+ success=True
98
+ )
99
+
100
+
101
+ if __name__ == "__main__":
102
+ import uvicorn
103
+ uvicorn.run(app, host="0.0.0.0", port=8000)
backend/initiate.py DELETED
@@ -1,13 +0,0 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
2
-
3
- MODEL = "Qwen/Qwen2.5-7B-Instruct.gguf.q5_0"
4
-
5
- tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
6
-
7
- model = AutoModelForCausalLM.from_pretrained(
8
- MODEL,
9
- device_map="auto",
10
- load_in_4bit=True,
11
- torch_dtype="auto",
12
- trust_remote_code=True
13
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,7 +1,12 @@
1
- pytorch
 
2
  transformers
3
  accelerate
4
  bitsandbytes
5
- summarizer
 
6
  sentencepiece
 
 
7
  fastapi
 
 
1
+ # Core ML
2
+ torch
3
  transformers
4
  accelerate
5
  bitsandbytes
6
+ peft
7
+ datasets
8
  sentencepiece
9
+
10
+ # API
11
  fastapi
12
+ uvicorn
scripts/evaluate.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """CLI evaluation script for Précis."""
3
+
4
+ import argparse
5
+ import logging
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ sys.path.insert(0, str(Path(__file__).parent.parent))
10
+
11
+ from src.config import ModelConfig, DataConfig
12
+ from src.model import load_tokenizer
13
+ from src.tuning.data import create_dummy_data
14
+
15
+ from transformers import AutoModelForCausalLM
16
+ from peft import PeftModel
17
+
18
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def parse_args():
    """Parse command-line options for the evaluation run."""
    p = argparse.ArgumentParser(description="Evaluate Précis model")
    # (flag, type, default, help) rows keep the option table compact;
    # a None default marks the flag as required.
    rows = [
        ("--checkpoint", str, None, "Path to model checkpoint"),
        ("--num-samples", int, 5, "Number of samples to evaluate"),
        ("--max-new-tokens", int, 256, "Max tokens to generate"),
    ]
    for flag, kind, default, text in rows:
        if default is None:
            p.add_argument(flag, type=kind, required=True, help=text)
        else:
            p.add_argument(flag, type=kind, default=default, help=text)
    return p.parse_args()
28
+
29
+
30
def main():
    """Load a fine-tuned checkpoint and print generations on dummy samples."""
    args = parse_args()
    model_cfg = ModelConfig()
    data_cfg = DataConfig()

    logger.info(f"Loading checkpoint from {args.checkpoint}")
    tokenizer = load_tokenizer(model_cfg)

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint,
        device_map="auto",
        trust_remote_code=True,
    )

    # Generate on synthetic samples (no real eval set is wired up yet).
    for idx, sample in enumerate(create_dummy_data(args.num_samples), start=1):
        prompt = data_cfg.format_prompt(sample["text"])
        batch = tokenizer(prompt, return_tensors="pt").to(model.device)

        # NOTE(review): do_sample + temperature makes this evaluation
        # non-deterministic; consider greedy decoding or a fixed seed.
        generation = model.generate(
            **batch,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )

        decoded = tokenizer.decode(generation[0], skip_special_tokens=True)
        # Strip the echoed prompt prefix to isolate the generated summary.
        summary = decoded[len(prompt):]

        logger.info(f"\n=== Sample {idx} ===")
        logger.info(f"Input: {sample['text'][:100]}...")
        logger.info(f"Generated: {summary}")
        logger.info(f"Reference: {sample['summary']}")
66
+
67
+
68
+ if __name__ == "__main__":
69
+ main()
scripts/train.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """CLI training script for Précis."""
3
+
4
+ import argparse
5
+ import logging
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ # Add project root to path
10
+ sys.path.insert(0, str(Path(__file__).parent.parent))
11
+
12
+ from src.config import ModelConfig, TrainingConfig, DataConfig
13
+ from src.model import load_model, load_tokenizer, prepare_for_training
14
+ from src.tuning.lora import apply_lora
15
+ from src.tuning.data import create_dummy_data, prepare_dataset
16
+ from src.tuning.trainer import PrecisTrainer
17
+
18
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def parse_args():
    """Parse command-line options for a training run."""
    p = argparse.ArgumentParser(description="Train Précis summarization model")
    # Model / output locations
    p.add_argument("--model-id", type=str, default=None, help="HuggingFace model ID")
    p.add_argument("--output-dir", type=str, default="./outputs", help="Output directory")
    # Optimization hyperparameters
    p.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
    p.add_argument("--batch-size", type=int, default=4, help="Batch size")
    p.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")
    p.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    # Pipeline controls
    p.add_argument("--dry-run", action="store_true", help="Validate pipeline without training")
    p.add_argument("--dummy-samples", type=int, default=100, help="Number of dummy samples")
    return p.parse_args()
33
+
34
+
35
def main():
    """Entry point: build configs, then either dry-run-validate or fine-tune."""
    args = parse_args()

    # Translate CLI flags into typed config objects.
    model_cfg = ModelConfig()
    if args.model_id:
        model_cfg.model_id = args.model_id

    train_cfg = TrainingConfig(
        output_dir=args.output_dir,
        num_epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lora_r=args.lora_r,
    )
    data_cfg = DataConfig()

    if args.dry_run:
        # Validate configuration and the data pipeline without loading the model.
        logger.info("=== DRY RUN MODE ===")
        logger.info(f"Model: {model_cfg.model_id}")
        logger.info(f"Output: {train_cfg.output_dir}")
        logger.info(f"Epochs: {train_cfg.num_epochs}, Batch: {train_cfg.batch_size}")
        logger.info(f"LoRA r: {train_cfg.lora_r}, alpha: {train_cfg.lora_alpha}")

        sample_batch = create_dummy_data(5)
        logger.info(f"Dummy data sample: {sample_batch[0]}")
        logger.info("Dry run complete. Pipeline validated.")
        return

    # Quantized base model -> k-bit training prep -> LoRA adapters.
    logger.info("Loading model and tokenizer...")
    tokenizer = load_tokenizer(model_cfg)
    model = apply_lora(prepare_for_training(load_model(model_cfg)), train_cfg)

    # Dummy corpus stands in until a real dataset is wired up.
    logger.info("Preparing training data...")
    train_dataset = prepare_dataset(create_dummy_data(args.dummy_samples), tokenizer, data_cfg)

    trainer = PrecisTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_cfg,
    )
    trainer.train()
    trainer.save()

    logger.info("Training complete!")
88
+
89
+
90
+ if __name__ == "__main__":
91
+ main()
src/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Précis — Model loading, configuration, and fine-tuning utilities.
3
+ """
4
+
5
+ from src.config import ModelConfig, TrainingConfig, DataConfig
6
+ from src.model import load_model, load_tokenizer, prepare_for_training
7
+
8
+ __version__ = "0.1.0"
9
+ __all__ = [
10
+ "ModelConfig",
11
+ "TrainingConfig",
12
+ "DataConfig",
13
+ "load_model",
14
+ "load_tokenizer",
15
+ "prepare_for_training",
16
+ ]
src/config.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration management for Précis."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional, List
5
+
6
+
7
@dataclass
class ModelConfig:
    """Settings controlling how the base model is loaded and quantized."""
    model_id: str = "Qwen/Qwen2.5-7B-Instruct"
    # The two quantization modes are mutually exclusive; checked in __post_init__.
    load_in_4bit: bool = True
    load_in_8bit: bool = False
    # BitsAndBytes 4-bit knobs (only consulted when load_in_4bit is True).
    bnb_4bit_compute_dtype: str = "float16"
    bnb_4bit_quant_type: str = "nf4"
    bnb_4bit_use_double_quant: bool = True
    device_map: str = "auto"
    trust_remote_code: bool = True
    cache_dir: Optional[str] = None

    def __post_init__(self):
        # Guard clause: refuse contradictory quantization settings early.
        if not (self.load_in_4bit and self.load_in_8bit):
            return
        raise ValueError("Cannot enable both 4-bit and 8-bit quantization")
23
+
24
+
25
@dataclass
class TrainingConfig:
    """Configuration for LoRA fine-tuning."""
    # --- LoRA adapter hyperparameters ---
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # Attention projection layers to wrap with LoRA adapters.
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
    # --- Optimizer / schedule ---
    learning_rate: float = 2e-4
    batch_size: int = 4
    # Effective batch size is batch_size * gradient_accumulation_steps.
    gradient_accumulation_steps: int = 4
    num_epochs: int = 3
    warmup_ratio: float = 0.03
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    # NOTE(review): not consumed by PrecisTrainer; sequence length is budgeted
    # in DataConfig instead — confirm intended use before relying on it.
    max_seq_length: int = 2048
    # Paged AdamW keeps optimizer state manageable alongside quantized weights.
    optim: str = "paged_adamw_32bit"
    # --- Checkpointing / logging cadence (in optimizer steps) ---
    save_steps: int = 100
    logging_steps: int = 10
    eval_steps: int = 100
    output_dir: str = "./outputs"
    seed: int = 42
48
+
49
+
50
@dataclass
class DataConfig:
    """Dataset locations, column names, length budgets, and the prompt template."""
    train_file: Optional[str] = None
    eval_file: Optional[str] = None
    input_column: str = "text"
    target_column: str = "summary"
    # Token budgets for the document and the generated summary respectively.
    max_input_length: int = 1536
    max_target_length: int = 512
    train_split: float = 0.9
    # Instruction template; {input} is replaced with the document text.
    prompt_template: str = (
        "Summarize the following document:\n\n"
        "### Document:\n{input}\n\n"
        "### Summary:\n"
    )

    def format_prompt(self, text: str) -> str:
        """Render the instruction prompt for a single document."""
        rendered = self.prompt_template.format(input=text)
        return rendered
src/model.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model loading utilities for Précis."""
2
+
3
+ import logging
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ from transformers import (
8
+ AutoModelForCausalLM,
9
+ AutoTokenizer,
10
+ BitsAndBytesConfig,
11
+ PreTrainedModel,
12
+ PreTrainedTokenizer,
13
+ )
14
+
15
+ from src.config import ModelConfig
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def get_quantization_config(config: ModelConfig) -> Optional[BitsAndBytesConfig]:
    """Translate a ModelConfig into a BitsAndBytes config, or None for full precision."""
    if config.load_in_4bit:
        # Resolve the dtype name (e.g. "float16") to the torch dtype object.
        compute = getattr(torch, config.bnb_4bit_compute_dtype)
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute,
            bnb_4bit_quant_type=config.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=config.bnb_4bit_use_double_quant,
        )
    if config.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None
33
+
34
+
35
def load_tokenizer(config: Optional[ModelConfig] = None) -> PreTrainedTokenizer:
    """Fetch the tokenizer for the configured model and normalize its padding setup."""
    cfg = config if config is not None else ModelConfig()

    logger.info(f"Loading tokenizer: {cfg.model_id}")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_id,
        trust_remote_code=cfg.trust_remote_code,
        cache_dir=cfg.cache_dir,
    )

    # Causal LMs often ship without a pad token; reuse EOS so batching works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Right-side padding keeps prompts left-aligned for causal-LM training.
    tokenizer.padding_side = "right"

    return tokenizer
53
+
54
+
55
def load_model(config: Optional[ModelConfig] = None) -> PreTrainedModel:
    """Load the base causal LM, applying quantization when configured."""
    cfg = config if config is not None else ModelConfig()

    logger.info(f"Loading model: {cfg.model_id}")
    quant_cfg = get_quantization_config(cfg)

    # Quantized weights need an explicit fp16 compute dtype; otherwise defer
    # to whatever the checkpoint declares.
    dtype = torch.float16 if quant_cfg else "auto"
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_id,
        quantization_config=quant_cfg,
        device_map=cfg.device_map,
        trust_remote_code=cfg.trust_remote_code,
        cache_dir=cfg.cache_dir,
        torch_dtype=dtype,
    )

    logger.info(f"Model loaded. Parameters: {model.num_parameters():,}")
    return model
74
+
75
+
76
def prepare_for_training(model: PreTrainedModel, gradient_checkpointing: bool = True) -> PreTrainedModel:
    """Ready a (possibly quantized) model for fine-tuning.

    Enables gradient checkpointing to trade compute for memory, then runs
    PEFT's k-bit preparation when the model was loaded quantized.
    """
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()

    quantized = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
    if quantized:
        # Imported lazily so non-quantized flows don't need peft here.
        from peft import prepare_model_for_kbit_training
        model = prepare_model_for_kbit_training(model)

    return model
86
+
87
+
88
def load_model_and_tokenizer(config: Optional[ModelConfig] = None) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Convenience wrapper returning (model, tokenizer) built from one config."""
    cfg = config if config is not None else ModelConfig()
    return load_model(cfg), load_tokenizer(cfg)
src/tuning/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tuning subpackage for Précis."""
2
+
3
+ from src.tuning.lora import get_lora_config, apply_lora, merge_and_save
4
+ from src.tuning.data import SummarizationDataset, prepare_dataset, create_dummy_data
5
+ from src.tuning.trainer import PrecisTrainer
6
+
7
+ __all__ = [
8
+ "get_lora_config",
9
+ "apply_lora",
10
+ "merge_and_save",
11
+ "SummarizationDataset",
12
+ "prepare_dataset",
13
+ "create_dummy_data",
14
+ "PrecisTrainer",
15
+ ]
src/tuning/data.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data preparation utilities for training."""
2
+
3
+ import logging
4
+ from typing import Dict, List, Optional, Any
5
+
6
+ from torch.utils.data import Dataset
7
+ from transformers import PreTrainedTokenizer
8
+
9
+ from src.config import DataConfig
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class SummarizationDataset(Dataset):
    """PyTorch Dataset producing causal-LM training examples for summarization.

    Each item is tokenized as ``prompt + target + EOS`` and returned as
    fixed-length tensors padded/truncated to
    ``max_input_length + max_target_length``.
    """

    def __init__(
        self,
        data: List[Dict[str, str]],
        tokenizer: PreTrainedTokenizer,
        config: Optional[DataConfig] = None,
    ):
        """
        Args:
            data: Records keyed by the configured input/target column names.
            tokenizer: Tokenizer whose pad token is already set.
            config: Column names, prompt template, and length budgets.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.config = config or DataConfig()

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.data[idx]
        prompt = self.config.format_prompt(item[self.config.input_column])
        full_text = prompt + item[self.config.target_column] + self.tokenizer.eos_token

        encoding = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.config.max_input_length + self.config.max_target_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()

        # BUG FIX: labels were the same tensor as input_ids, so the loss was
        # also computed on padding positions. Clone (labels must not alias
        # input_ids) and set padding to -100, which cross-entropy ignores.
        # NOTE(review): prompt tokens are still supervised; masking them to
        # -100 as well is common for instruction tuning — confirm intent.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
48
+
49
+
50
def create_dummy_data(num_samples: int = 10) -> List[Dict[str, str]]:
    """Build a small synthetic corpus for exercising the training pipeline."""
    # Topics cycle through three values so the data has a little variety.
    samples = [
        {
            "text": f"This is sample document {i}. It contains information about topic {i % 3}. "
                    f"The document discusses various aspects and provides detailed analysis. "
                    f"Key points include methodology, results, and conclusions for study {i}.",
            "summary": f"Document {i} analyzes topic {i % 3}, covering methodology, results, and conclusions.",
        }
        for i in range(num_samples)
    ]
    logger.info(f"Created {num_samples} dummy samples")
    return samples
62
+
63
+
64
def prepare_dataset(
    data: List[Dict[str, str]],
    tokenizer: PreTrainedTokenizer,
    config: Optional[DataConfig] = None,
) -> SummarizationDataset:
    """Thin factory wrapping raw records in a SummarizationDataset."""
    dataset = SummarizationDataset(data, tokenizer, config)
    return dataset
src/tuning/lora.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LoRA/PEFT configuration and utilities."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from peft import LoraConfig, get_peft_model, PeftModel, TaskType
8
+ from transformers import PreTrainedModel
9
+
10
+ from src.config import TrainingConfig
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def get_lora_config(config: Optional[TrainingConfig] = None) -> LoraConfig:
    """Build the PEFT LoraConfig for causal-LM summarization fine-tuning."""
    cfg = config if config is not None else TrainingConfig()

    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        target_modules=cfg.lora_target_modules,
        bias="none",  # adapt projection weights only; leave biases frozen
    )
28
+
29
+
30
def apply_lora(model: PreTrainedModel, config: Optional[TrainingConfig] = None) -> PeftModel:
    """Wrap the base model with LoRA adapters and report trainable parameters."""
    lora_cfg = get_lora_config(config)
    logger.info(f"Applying LoRA with r={lora_cfg.r}, alpha={lora_cfg.lora_alpha}")

    wrapped = get_peft_model(model, lora_cfg)
    wrapped.print_trainable_parameters()
    return wrapped
39
+
40
+
41
def merge_and_save(model: PeftModel, output_path: str, tokenizer=None) -> None:
    """Fold the LoRA weights into the base model and write it to output_path."""
    target = Path(output_path)
    target.mkdir(parents=True, exist_ok=True)

    logger.info("Merging LoRA weights...")
    standalone = model.merge_and_unload()

    logger.info(f"Saving merged model to {target}")
    standalone.save_pretrained(target)

    # Keep the tokenizer next to the weights when one is provided.
    if tokenizer:
        tokenizer.save_pretrained(target)
src/tuning/trainer.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training orchestration for Précis."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from transformers import (
8
+ Trainer,
9
+ TrainingArguments,
10
+ PreTrainedModel,
11
+ PreTrainedTokenizer,
12
+ DataCollatorForLanguageModeling,
13
+ )
14
+ from torch.utils.data import Dataset
15
+
16
+ from src.config import TrainingConfig
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class PrecisTrainer:
    """Wrapper around HuggingFace Trainer for summarization fine-tuning."""

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        train_dataset: Dataset,
        eval_dataset: Optional[Dataset] = None,
        config: Optional[TrainingConfig] = None,
    ):
        """Store components and eagerly build TrainingArguments + Trainer.

        Args:
            model: Model to fine-tune (typically LoRA-wrapped).
            tokenizer: Tokenizer; saved alongside checkpoints by save().
            eval_dataset: Optional; when present, step-based evaluation and
                best-model loading are enabled.
            config: Hyperparameters; defaults to TrainingConfig().
        """
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config or TrainingConfig()

        # Built once up front; mutating self.config afterwards has no effect.
        self.training_args = self._create_training_args()
        self.trainer = self._create_trainer()

    def _create_training_args(self) -> TrainingArguments:
        """Create HuggingFace TrainingArguments from config."""
        # NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in
        # newer transformers releases and requirements.txt does not pin a
        # version — confirm the installed version accepts this kwarg.
        # fp16=True also presumes a CUDA-capable device — verify deployment.
        return TrainingArguments(
            output_dir=self.config.output_dir,
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            warmup_ratio=self.config.warmup_ratio,
            weight_decay=self.config.weight_decay,
            max_grad_norm=self.config.max_grad_norm,
            optim=self.config.optim,
            logging_steps=self.config.logging_steps,
            save_steps=self.config.save_steps,
            # Evaluation knobs are only active when an eval dataset exists.
            eval_steps=self.config.eval_steps if self.eval_dataset else None,
            evaluation_strategy="steps" if self.eval_dataset else "no",
            save_total_limit=3,
            load_best_model_at_end=bool(self.eval_dataset),
            seed=self.config.seed,
            fp16=True,
            report_to="none",
        )

    def _create_trainer(self) -> Trainer:
        """Create HuggingFace Trainer instance."""
        # mlm=False selects the causal-LM collation path (next-token labels).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        return Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=data_collator,
        )

    def train(self) -> None:
        """Execute training loop."""
        logger.info("Starting training...")
        self.trainer.train()
        logger.info("Training complete.")

    def evaluate(self) -> dict:
        """Run evaluation and return metrics.

        Returns an empty dict (with a warning) when no eval dataset was given.
        """
        if self.eval_dataset is None:
            logger.warning("No eval dataset provided")
            return {}

        logger.info("Running evaluation...")
        return self.trainer.evaluate()

    def save(self, output_path: Optional[str] = None) -> None:
        """Save model checkpoint.

        Args:
            output_path: Destination directory; falls back to config.output_dir.
        """
        path = output_path or self.config.output_dir
        Path(path).mkdir(parents=True, exist_ok=True)
        self.trainer.save_model(path)
        # Persist the tokenizer too so the checkpoint is self-contained.
        self.tokenizer.save_pretrained(path)
        logger.info(f"Model saved to {path}")