Nutnell committed on
Commit
a89712a
·
verified ·
1 Parent(s): f2dc681

Update fine_tune.py

Browse files
Files changed (1) hide show
  1. fine_tune.py +19 -12
fine_tune.py CHANGED
@@ -1,4 +1,3 @@
1
- # fine_tune.py
2
  import os
3
  import torch
4
  from datasets import load_dataset
@@ -11,19 +10,19 @@ from transformers import (
11
  from peft import LoraConfig, PeftModel
12
  from trl import SFTTrainer
13
  from fastapi import FastAPI
 
14
  import uvicorn
15
 
16
-
17
  base_model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit"
18
-
19
- output_dir = "/data/fine_tuning"
20
  dataset_path = "dataset.jsonl"
21
 
22
- # Initialize model and tokenizer variables
23
  model = None
24
  tokenizer = None
25
 
26
- # Training Logic
27
  # Check if a fine-tuned model adapter already exists
28
  if not os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
29
  print("No fine-tuned model found. Starting training...")
@@ -70,7 +69,7 @@ if not os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
70
  group_by_length=True,
71
  lr_scheduler_type="linear",
72
  push_to_hub=True,
73
- hub_model_id="Nutnell/DirectEd-AI",
74
  )
75
 
76
  # Initialize Trainer
@@ -91,7 +90,7 @@ if not os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
91
 
92
  model = trainer.model
93
 
94
- # Inference Logic
95
  # If training did not run, load the existing model
96
  else:
97
  print("Found existing fine-tuned model. Loading for inference...")
@@ -107,22 +106,30 @@ else:
107
  tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
108
 
109
 
110
- # Create Inference Pipeline
111
  print("Setting up inference pipeline...")
112
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
113
  print("Inference pipeline ready.")
114
 
115
- # FastAPI App
 
 
 
 
 
 
116
  app = FastAPI(title="Fine-tuned LLaMA API")
117
 
118
  @app.get("/")
119
  def home():
120
  return {"status": "ok", "message": "Fine-tuned LLaMA is ready."}
121
 
 
122
  @app.post("/generate")
123
- def generate(prompt: str):
 
 
124
 
125
- formatted_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
126
  outputs = pipe(formatted_prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
127
  return {"response": outputs[0]["generated_text"]}
128
 
 
 
1
  import os
2
  import torch
3
  from datasets import load_dataset
 
10
  from peft import LoraConfig, PeftModel
11
  from trl import SFTTrainer
12
  from fastapi import FastAPI
13
+ from pydantic import BaseModel # 1. ADD THIS IMPORT
14
  import uvicorn
15
 
16
+ # --- Configuration ---
17
  base_model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit"
18
+ output_dir = "/data/fine_tuning"
 
19
  dataset_path = "dataset.jsonl"
20
 
21
+ # --- Initialize model and tokenizer variables ---
22
  model = None
23
  tokenizer = None
24
 
25
+ # --- Training Logic ---
26
  # Check if a fine-tuned model adapter already exists
27
  if not os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
28
  print("No fine-tuned model found. Starting training...")
 
69
  group_by_length=True,
70
  lr_scheduler_type="linear",
71
  push_to_hub=True,
72
+ hub_model_id="Nutnell/DirectEd-AI",
73
  )
74
 
75
  # Initialize Trainer
 
90
 
91
  model = trainer.model
92
 
93
+ # --- Inference Logic ---
94
  # If training did not run, load the existing model
95
  else:
96
  print("Found existing fine-tuned model. Loading for inference...")
 
106
  tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
107
 
108
 
109
+ # --- Create Inference Pipeline ---
110
  print("Setting up inference pipeline...")
111
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
112
  print("Inference pipeline ready.")
113
 
114
+
115
+ # --- FastAPI App ---
116
+
117
+ # 2. DEFINE THE PYDANTIC MODEL FOR THE REQUEST BODY
118
+ class GenerateRequest(BaseModel):
119
+ prompt: str
120
+
121
  app = FastAPI(title="Fine-tuned LLaMA API")
122
 
123
  @app.get("/")
124
  def home():
125
  return {"status": "ok", "message": "Fine-tuned LLaMA is ready."}
126
 
127
+ # 3. UPDATE THE GENERATE FUNCTION TO USE THE PYDANTIC MODEL
128
  @app.post("/generate")
129
+ def generate(request: GenerateRequest):
130
+ # Access the prompt from the request object
131
+ formatted_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{request.prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
132
 
 
133
  outputs = pipe(formatted_prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
134
  return {"response": outputs[0]["generated_text"]}
135