# CheckInAPI / main.py
# NOTE(review): the lines below are Hugging Face Hub raw-file page chrome that
# was scraped into this file (uploader "ethnmcl", commit 1bb585d "Update
# main.py", 4.19 kB). They are preserved as comments so the module parses as
# valid Python.
import os
from typing import Dict, Any
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
# === Config ===
# Model repo to serve; override via the MODEL_ID env var (e.g. in Space settings).
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")  # if the repo is private, set this in Secrets
app = FastAPI(title="Check-in GPT-2 API", version="1.1.0")
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by the browser CORS spec (a wildcard origin cannot be credentialed).
# Works for non-credentialed callers; confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
)
# Choose device: GPU index 0 if available else CPU
device = 0 if torch.cuda.is_available() else -1
# === Load tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# GPT-2-family tokenizers ship without a pad token; reuse EOS so padding in
# the generation pipeline works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# === Load model (supports plain CausalLM repos AND PEFT LoRA adapters) ===
# Strategy:
# 1) Try plain AutoModelForCausalLM
# 2) If that fails (likely LoRA-only repo), try PEFT AutoPeftModelForCausalLM and merge
_model = None
_merged = False  # True once LoRA weights have been folded into the base model
try:
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        # Use 'dtype' not deprecated 'torch_dtype'
        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        # On GPU, let accelerate shard/place the weights; on CPU load normally.
        device_map="auto" if torch.cuda.is_available() else None,
    )
except Exception as e_plain:
    # Fall back to PEFT path: the repo may contain only LoRA adapter weights,
    # which AutoModelForCausalLM cannot load directly.
    try:
        from peft import AutoPeftModelForCausalLM
        _model = AutoPeftModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        # Merge LoRA into base weights so inference behaves like a standard CausalLM
        try:
            _model = _model.merge_and_unload()
            _merged = True
        except Exception:
            # If merge not available, we still can run with adapters active
            _merged = False
    except Exception as e_peft:
        # Surface BOTH failure modes so misconfigured MODEL_ID/HF_TOKEN is
        # diagnosable from the startup log.
        raise RuntimeError(
            f"Failed to load model '{MODEL_ID}'. "
            f"Plain load error: {e_plain}\nPEFT load error: {e_peft}"
        )
# Build pipeline.
# BUGFIX: when the model was loaded with device_map="auto" (the CUDA path
# above), accelerate has already placed its weights, and passing `device=` to
# pipeline() raises "The model has been loaded with `accelerate` and therefore
# cannot be moved to a specific device". Only pass an explicit device for
# models that were loaded without a device map (the CPU path).
_pipe_kwargs: Dict[str, Any] = {}
if getattr(_model, "hf_device_map", None) is None:
    _pipe_kwargs["device"] = device
pipe = pipeline(
    "text-generation",
    model=_model,
    tokenizer=tokenizer,
    **_pipe_kwargs,
)
# Prompt template used at fine-tuning time (INPUT/OUTPUT marker format).
# Keep these markers if the model was trained with them; otherwise switch to
# a plain 'Check-in: ' prefix.
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"


def make_prompt(user_input: str) -> str:
    """Wrap *user_input* in the INPUT/OUTPUT markers the model expects."""
    return "".join((PREFIX, user_input, SUFFIX))
class GenerateRequest(BaseModel):
    """Request body for POST /generate.

    This is untrusted external input, so the sampling parameters are bounded:
    previously a caller could request an arbitrarily large generation
    (max_new_tokens / num_return_sequences — a DoS vector) or pass invalid
    sampler values (temperature <= 0 crashes sampling). Defaults unchanged.
    """
    # Free-text check-in to complete; must be non-empty.
    input: str = Field(..., min_length=1)
    max_new_tokens: int = Field(180, ge=1, le=1024)
    temperature: float = Field(0.7, gt=0.0, le=2.0)
    top_p: float = Field(0.95, gt=0.0, le=1.0)
    top_k: int = Field(50, ge=0)
    repetition_penalty: float = Field(1.05, gt=0.0)
    do_sample: bool = True
    num_return_sequences: int = Field(1, ge=1, le=5)
class GenerateResponse(BaseModel):
    """Response body for POST /generate."""
    # Generated completion text (the part after the OUTPUT: marker).
    output: str
    # Full prompt that was sent to the model, including INPUT/OUTPUT markers.
    prompt: str
    # Echo of the generation parameters used for this request.
    parameters: Dict[str, Any]
@app.get("/")
def root():
    """Service metadata: usage hint, model id, device, and LoRA-merge status."""
    device_label = "cuda" if device == 0 else "cpu"
    info = {"message": "Check-in GPT-2 API. POST /generate"}
    info["model"] = MODEL_ID
    info["device"] = device_label
    info["merged_lora"] = _merged
    return info
@app.get("/health")
def health():
    """Liveness probe — always reports ok once the app has started."""
    return dict(status="ok")
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a completion for a check-in input.

    Wraps the input in the INPUT/OUTPUT prompt template, runs the
    text-generation pipeline with the requested sampling parameters, and
    returns the completion text. Any backend failure surfaces as HTTP 500
    with the underlying error message.
    """
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # BUGFIX: ask the pipeline for the completion only. The previous
            # code requested the full text and split on "OUTPUT:", which
            # miscut the response whenever the user's own input contained
            # that marker (the split matched inside the prompt).
            return_full_text=False,
        )
        # Only the first sequence is returned even if num_return_sequences > 1;
        # the response schema carries a single output string.
        output = gen[0]["generated_text"].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors as generic 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))