Nutnell committed
Commit bae97e6 · verified · Parent: 6b517e7

Update fine_tune.py

Files changed (1): fine_tune.py +81 -28
fine_tune.py CHANGED
@@ -105,35 +105,88 @@ print("Inference pipeline ready.")
 class GenerateRequest(BaseModel):
     prompt: str
 
-app = FastAPI(title="Fine-tuned LLaMA API")
+app = FastAPI(
+    title="DirectEd AI Assistant",
+    version="1.0",
+    description="API for fine-tuned DirectEd AI chatbot."
+)
+
+# --- Load Model + Tokenizer ---
+try:
+    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        base_model_name,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto"
+    )
 
+    if os.path.exists(output_dir):
+        print(f"Loading adapter from {output_dir}")
+        model = PeftModel.from_pretrained(model, output_dir)
+    else:
+        print("⚠️ No adapter folder found, using base model only")
+
+except Exception as e:
+    print("❌ Model load failed:", e)
+    model, tokenizer = None, None
+
+
+# --- Routes ---
 @app.get("/")
-def home():
-    return {"status": "ok", "message": "Fine-tuned LLaMA is ready."}
+def health():
+    return {"status": "ok", "message": "DirectEd AI Space running."}
+
 
 @app.post("/generate")
-def generate(request: GenerateRequest):
-    formatted_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{request.prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-    outputs = pipe(formatted_prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
-    return {"response": outputs[0]["generated_text"]}
-
-# --- Extra utility endpoints ---
-@app.get("/list-files")
-def list_files():
-    files = []
-    for root, _, filenames in os.walk(output_dir):
-        for fname in filenames:
-            files.append(os.path.relpath(os.path.join(root, fname), output_dir))
-    return {"files": files}
-
-@app.post("/push-to-hub")
-def push_to_hub():
-    try:
-        model.push_to_hub(hub_model_id)
-        tokenizer.push_to_hub(hub_model_id)
-        return {"status": "success", "message": f"Pushed to Hugging Face Hub ({hub_model_id})"}
-    except Exception as e:
-        return {"status": "error", "message": str(e)}
-
-if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+def generate(prompt: str, max_new_tokens: int = 200):
+    if model is None or tokenizer is None:
+        return {"error": "Model not loaded."}
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_k=50,
+        top_p=0.9
+    )
+    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return {"response": text}
+
+
+@app.get("/list_adapter")
+def list_adapter():
+    """List adapter files in output_dir"""
+    if os.path.exists(output_dir):
+        files = os.listdir(output_dir)
+        return {"adapter_files": files}
+    return {"adapter_files": [], "message": "No adapter directory found."}
+
+
+@app.post("/upload_adapter")
+def upload_adapter(file: UploadFile = File(...)):
+    """Upload adapter files (e.g. adapter_config.json, adapter_model.bin)"""
+    os.makedirs(output_dir, exist_ok=True)
+    save_path = os.path.join(output_dir, file.filename)
+    with open(save_path, "wb") as buffer:
+        shutil.copyfileobj(file.file, buffer)
+    return {"status": "success", "filename": file.filename}
+
+
+@app.post("/push_adapter")
+def push_adapter():
+    """Push adapter folder to Hugging Face Hub"""
+    if not os.path.exists(output_dir):
+        return {"error": "No adapter folder found."}
+
+    files = os.listdir(output_dir)
+    if not files:
+        return {"error": "Adapter folder is empty."}
+
+    upload_folder(
+        repo_id=hub_repo_id,
+        folder_path=output_dir,
+        commit_message="Upload LoRA adapter from Space"
+    )
+    return {"status": "uploaded", "repo": f"https://huggingface.co/{hub_repo_id}", "files": files}