from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

# Web application serving the model; routes are registered below.
app = FastAPI()

# Base checkpoint the LoRA adapter is applied on top of.
BASE_MODEL = "gpt2"

# Hugging Face Hub repo id for the LoRA adapter weights.
# NOTE(review): repo id copied verbatim — confirm it exists and matches BASE_MODEL.
LORA_REPO = "hello-ram/unsolth_gpt.20"
|
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Bug fix: GPT-2's tokenizer ships without a pad token, which makes
# generate() emit a warning on every call and can mishandle attention
# masks when inputs are padded. Standard workaround: reuse EOS as pad.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    # fp16 halves memory; assumes a CUDA device is present — many fp16 ops
    # are slow or unsupported on CPU. TODO(review): confirm deployment target.
    torch_dtype=torch.float16,
    device_map="auto",
)

print("Applying LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_REPO,
    device_map="auto",
)

# Inference only: disables dropout and other train-time behavior.
model.eval()
| |
|
| |
|
@app.get("/")
async def root():
    """Landing route: confirm the service is up and advertise its endpoints."""
    payload = {
        "msg": "LoRA model running",
        "endpoints": ["/status", "/generate"],
    }
    return payload
| |
|
| |
|
@app.get("/status")
async def status():
    """Report service health and which base model / LoRA adapter are loaded."""
    info = {"status": "ok"}
    info["base_model"] = BASE_MODEL
    info["lora_model"] = LORA_REPO
    # model.device reflects where the (possibly sharded) weights were placed.
    info["device"] = str(model.device)
    return info
| |
|
| |
|
class InputText(BaseModel):
    """Request body for POST /generate."""

    text: str  # prompt passed verbatim to the tokenizer
| |
|
| |
|
@app.post("/generate")
async def generate_text(data: InputText):
    """Generate a completion for the submitted prompt.

    Tokenizes ``data.text``, samples up to 200 new tokens from the LoRA
    model, and returns the decoded text (prompt + continuation) under the
    ``"response"`` key.
    """
    inputs = tokenizer(data.text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            # Bug fix: temperature has no effect under the default greedy
            # decoding — sampling must be enabled for it to apply.
            do_sample=True,
            # Bug fix: GPT-2 has no pad token; pass EOS explicitly so
            # generate() does not warn / guess on every request.
            pad_token_id=tokenizer.eos_token_id,
        )

    # output[0]: the single generated sequence (batch size is 1 here).
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return {"response": text}
| |
|