# mpt-space / app.py
# Author: hello-ram
# Update app.py (commit 777ec21, verified)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
app = FastAPI()

# 1. Base model checkpoint pulled from the Hugging Face Hub.
BASE_MODEL = "gpt2"
# 2. LoRA adapter repo layered on top of the base model.
LORA_REPO = "hello-ram/unsolth_gpt.20"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

print("Loading base model...")
# fp16 weights are only safe on GPU: on a CPU-only Space, half-precision ops
# are unsupported or extremely slow and can produce NaNs during generation,
# so fall back to fp32 when no CUDA device is present.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
    device_map="auto",
)

print("Applying LoRA adapter...")
# Wrap the base model with the LoRA weights from LORA_REPO.
model = PeftModel.from_pretrained(
    base_model,
    LORA_REPO,
    device_map="auto",
)
# Inference-only service: disable dropout / training-mode layers.
model.eval()
@app.get("/")
async def root():
    """Landing route: confirm the service is up and list its endpoints."""
    available = ["/status", "/generate"]
    return {"msg": "LoRA model running", "endpoints": available}
@app.get("/status")
async def status():
    """Health endpoint: report the configured base model, adapter, and device."""
    info = {
        "status": "ok",
        "base_model": BASE_MODEL,
        "lora_model": LORA_REPO,
        "device": str(model.device),
    }
    return info
class InputText(BaseModel):
    """Request body for POST /generate: the prompt text to continue."""

    text: str  # user-supplied prompt
@app.post("/generate")
async def generate_text(data: InputText):
    """Generate up to 200 new tokens continuing the prompt in ``data.text``.

    Returns a JSON object ``{"response": <prompt + continuation>}``; the
    decoded text includes the original prompt because the full output
    sequence is decoded.
    """
    inputs = tokenizer(data.text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            # Bug fix: do_sample defaults to False (greedy decoding), under
            # which `temperature` is silently ignored — enable sampling so
            # the temperature actually takes effect.
            do_sample=True,
            temperature=0.7,
            # GPT-2 defines no pad token; use EOS to avoid generate() warnings
            # and incorrect padding behavior.
            pad_token_id=tokenizer.eos_token_id,
        )
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return {"response": text}