# File size: 1,390 Bytes
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
# ASGI application object; the route decorators further down register on it.
app = FastAPI()

# 1. Base model
BASE_MODEL = "gpt2"

# 2. LoRA adapter repo
LORA_REPO = "hello-ram/unsolth_gpt.20"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

print("Loading base model...")
# device_map="auto" may place the model on CPU when no GPU is present.
# Half precision is only requested when CUDA is available; on CPU-only hosts
# fp16 kernels can be slow or unsupported depending on the torch build, so
# the weights are loaded in float32 there instead.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)

print("Applying LoRA adapter...")
# Wrap the base model with the adapter weights fetched from LORA_REPO.
model = PeftModel.from_pretrained(
    base_model,
    LORA_REPO,
    device_map="auto",
)

# The service only runs inference, so put the model in evaluation mode.
model.eval()
@app.get("/")
async def root():
    """Landing route: confirm the service is up and list its other endpoints."""
    available = ["/status", "/generate"]
    return {"msg": "LoRA model running", "endpoints": available}
@app.get("/status")
async def status():
    """Health route: report the configured model names and the model's device.

    Returns a dict with a fixed ``"status": "ok"`` marker, the base model
    identifier, the LoRA adapter repo, and ``str(model.device)``.
    """
    report = dict(status="ok", base_model=BASE_MODEL, lora_model=LORA_REPO)
    # Stringify the device so the payload is JSON-serialisable.
    report["device"] = str(model.device)
    return report
# Request body schema for POST /generate.
# Documented with comments rather than a docstring so the schema pydantic
# generates for this model stays exactly as it was.
class InputText(BaseModel):
    # Prompt text that the /generate handler tokenizes and feeds to the model.
    text: str
def _generate_completion(prompt: str, max_new_tokens: int = 200) -> str:
    """Run one blocking tokenize -> generate -> decode pass for *prompt*.

    Args:
        prompt: Non-empty text to condition the model on.
        max_new_tokens: Upper bound on the number of tokens generated.

    Returns:
        The decoded first sequence returned by ``model.generate`` with
        special tokens stripped.
    """
    # Reserve room for the generated tokens inside the tokenizer's advertised
    # window so prompt + completion cannot overrun the model's context size.
    prompt_budget = max(1, tokenizer.model_max_length - max_new_tokens)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=prompt_budget,
    ).to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # temperature is ignored under greedy decoding; sampling must be
            # switched on for it to take effect.
            do_sample=True,
            temperature=0.7,
            # The tokenizer may define no pad token (GPT-2 does not); reuse
            # EOS so generate() has an explicit padding id.
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)


@app.post("/generate")
async def generate_text(data: InputText):
    """Generate a continuation of ``data.text`` with the LoRA-adapted model.

    Args:
        data: Request body carrying the prompt in ``data.text``.

    Returns:
        ``{"response": <decoded text>}``.

    Raises:
        HTTPException: 422 when the prompt is empty, since an empty prompt
            tokenizes to a zero-length tensor that the model cannot consume.
    """
    # Function-scope imports keep this block self-contained; fastapi is
    # already a dependency of this file.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    if not data.text:
        raise HTTPException(status_code=422, detail="text must not be empty")

    # model.generate is synchronous and slow; running it in a worker thread
    # keeps the event loop free to serve other requests meanwhile.
    text = await run_in_threadpool(_generate_completion, data.text)
    return {"response": text}
|