Nutnell committed on
Commit
6b517e7
·
verified ·
1 Parent(s): 4146ecf

Update fine_tune.py

Browse files
Files changed (1) hide show
  1. fine_tune.py +25 -23
fine_tune.py CHANGED
@@ -17,20 +17,18 @@ import uvicorn
17
  base_model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit"
18
  output_dir = "/data/fine_tuning"
19
  dataset_path = "dataset.jsonl"
 
20
 
21
  # --- Initialize model and tokenizer variables ---
22
  model = None
23
  tokenizer = None
24
 
25
  # --- Training Logic ---
26
- # Check if a fine-tuned model adapter already exists
27
  if not os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
28
  print("No fine-tuned model found. Starting training...")
29
 
30
- # Load dataset
31
  dataset = load_dataset("json", data_files=dataset_path, split="train")
32
 
33
- # Load base model for training
34
  model = AutoModelForCausalLM.from_pretrained(
35
  base_model_name,
36
  device_map="auto",
@@ -53,7 +51,6 @@ if not os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
53
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
54
  )
55
 
56
- # Training args
57
  training_arguments = TrainingArguments(
58
  output_dir=output_dir,
59
  num_train_epochs=1,
@@ -69,10 +66,9 @@ if not os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
69
  group_by_length=True,
70
  lr_scheduler_type="linear",
71
  push_to_hub=True,
72
- hub_model_id="Nutnell/direct-ed-finetune-job",
73
  )
74
 
75
- # Initialize Trainer
76
  trainer = SFTTrainer(
77
  model=model,
78
  train_dataset=dataset,
@@ -81,40 +77,31 @@ if not os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
81
  args=training_arguments,
82
  )
83
 
84
- # Train the model
85
  trainer.train()
86
-
87
- # Save the trained adapter
88
  trainer.model.save_pretrained(output_dir)
89
  print(f"Fine-tuned model adapter saved to {output_dir}")
90
 
 
 
 
91
  model = trainer.model
92
 
93
- # --- Inference Logic ---
94
- # If training did not run, load the existing model
95
  else:
96
  print("Found existing fine-tuned model. Loading for inference...")
97
-
98
- # Load the base model
99
  base_model = AutoModelForCausalLM.from_pretrained(
100
  base_model_name,
101
  device_map="auto",
102
  trust_remote_code=True,
103
  )
104
- # Apply the PEFT adapter
105
  model = PeftModel.from_pretrained(base_model, output_dir)
106
  tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
107
 
108
-
109
- # --- Create Inference Pipeline ---
110
  print("Setting up inference pipeline...")
111
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
112
  print("Inference pipeline ready.")
113
 
114
-
115
  # --- FastAPI App ---
116
-
117
- # 2. DEFINE THE PYDANTIC MODEL FOR THE REQUEST BODY
118
  class GenerateRequest(BaseModel):
119
  prompt: str
120
 
@@ -124,14 +111,29 @@ app = FastAPI(title="Fine-tuned LLaMA API")
124
  def home():
125
  return {"status": "ok", "message": "Fine-tuned LLaMA is ready."}
126
 
127
- # 3. UPDATE THE GENERATE FUNCTION TO USE THE PYDANTIC MODEL
128
  @app.post("/generate")
129
  def generate(request: GenerateRequest):
130
- # Access the prompt from the request object
131
  formatted_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{request.prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
132
-
133
  outputs = pipe(formatted_prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
134
  return {"response": outputs[0]["generated_text"]}
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  if __name__ == "__main__":
137
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
17
# --- Configuration ---
# 4-bit quantized Llama-3 8B Instruct base model (Unsloth bnb build).
base_model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit"
# Directory where the trained LoRA adapter checkpoint is saved/loaded.
output_dir = "/data/fine_tuning"
# Training data in JSON-lines format.
dataset_path = "dataset.jsonl"
# Hugging Face Hub repo the fine-tuned adapter is pushed to.
hub_model_id = "Nutnell/direct-ed-finetune-job"
21
 
22
# --- Initialize model and tokenizer variables ---
# Populated by either the training branch or the load-existing branch below.
model = None
tokenizer = None

# --- Training Logic ---
 
27
# Train only when no adapter checkpoint exists yet in output_dir;
# adapter_config.json is the marker file PEFT writes on save.
if not os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
    print("No fine-tuned model found. Starting training...")

    # Load the JSONL training set as a single "train" split.
    dataset = load_dataset("json", data_files=dataset_path, split="train")
31
 
 
32
  model = AutoModelForCausalLM.from_pretrained(
33
  base_model_name,
34
  device_map="auto",
 
51
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
52
  )
53
 
 
54
  training_arguments = TrainingArguments(
55
  output_dir=output_dir,
56
  num_train_epochs=1,
 
66
  group_by_length=True,
67
  lr_scheduler_type="linear",
68
  push_to_hub=True,
69
+ hub_model_id=hub_model_id,
70
  )
71
 
 
72
  trainer = SFTTrainer(
73
  model=model,
74
  train_dataset=dataset,
 
77
  args=training_arguments,
78
  )
79
 
 
80
    trainer.train()

    # Persist the LoRA adapter locally so later runs take the inference branch.
    trainer.model.save_pretrained(output_dir)
    print(f"Fine-tuned model adapter saved to {output_dir}")

    # Push trained model to Hub
    trainer.push_to_hub()

    # Reuse the freshly trained model for inference below.
    model = trainer.model
88
 
 
 
89
else:
    print("Found existing fine-tuned model. Loading for inference...")

    # Load the quantized base model first...
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        device_map="auto",
        trust_remote_code=True,
    )
    # ...then attach the saved LoRA adapter from output_dir on top of it.
    model = PeftModel.from_pretrained(base_model, output_dir)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
98
 
99
# --- Inference Pipeline ---
# Build a text-generation pipeline around whichever model/tokenizer pair the
# branches above produced.
# NOTE(review): the training branch does not visibly assign `tokenizer` in this
# view — confirm it is set on that path before the pipeline is constructed.
print("Setting up inference pipeline...")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
print("Inference pipeline ready.")
103
 
 
104
  # --- FastAPI App ---
 
 
105
class GenerateRequest(BaseModel):
    """Request body for POST /generate."""

    # Raw user prompt; wrapped in the Llama-3 chat template by the endpoint.
    prompt: str
107
 
 
111
  def home():
112
  return {"status": "ok", "message": "Fine-tuned LLaMA is ready."}
113
 
 
114
@app.post("/generate")
def generate(request: GenerateRequest):
    """Generate a completion for the prompt in the request body.

    Wraps the prompt in the Llama-3 instruct chat template, runs the
    text-generation pipeline, and returns only the assistant's new text.
    """
    # Llama-3 instruct template: user turn followed by the assistant header.
    formatted_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{request.prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    # return_full_text=False strips the echoed prompt (and its special tokens)
    # from the pipeline output; without it, clients would receive the whole
    # chat-template scaffolding prepended to the reply.
    outputs = pipe(
        formatted_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        return_full_text=False,
    )
    return {"response": outputs[0]["generated_text"]}
119
 
120
# --- Extra utility endpoints ---
@app.get("/list-files")
def list_files():
    """List every file under the adapter output directory, relative to it."""
    discovered = [
        os.path.relpath(os.path.join(parent, name), output_dir)
        for parent, _, names in os.walk(output_dir)
        for name in names
    ]
    return {"files": discovered}
128
+
129
@app.post("/push-to-hub")
def push_to_hub():
    """Upload the current model and tokenizer to the configured Hub repo."""
    try:
        # Push the adapter weights first, then the tokenizer files.
        for artifact in (model, tokenizer):
            artifact.push_to_hub(hub_model_id)
    except Exception as exc:  # API-boundary handler: report failure as JSON
        return {"status": "error", "message": str(exc)}
    return {"status": "success", "message": f"Pushed to Hugging Face Hub ({hub_model_id})"}
137
+
138
# Serve the API when run as a script (7860 is presumably the Hugging Face
# Spaces port convention — confirm against the deployment config).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)