# CheckInAPI / main.py
import os
from typing import Any, Dict, Optional, Tuple

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# ---- Config --------------------------------------------------------------
PREFERRED_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
FALLBACK_IDS = ["ethnmcl/checkin-lora-gpt2", "distilgpt2"]  # last resort keeps the API alive
BASE_TOKENIZER = os.getenv("BASE_TOKENIZER", "gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")
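
# Example configuration (values are illustrative; HF_TOKEN is only required for
# private repos):
#   export MODEL_ID="ethnmcl/checkin-lora-gpt2"
#   export BASE_TOKENIZER="gpt2"
#   export HF_TOKEN="hf_..."
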
app = FastAPI(title="Check-in GPT-2 API", version="1.3.0")
app.add_middleware(
    CORSMiddleware,
    # NOTE: wide-open CORS; tighten allow_origins for production.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
device = 0 if torch.cuda.is_available() else -1
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# ---- Helpers -------------------------------------------------------------
def _load_tokenizer(repo_id: str) -> Tuple[Any, str, bool]:
    """Try the repo's tokenizer, then fall back to the base tokenizer.

    Returns (tokenizer, source_id, fallback_used).
    """
    try:
        tk = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token  # GPT-2 tokenizers ship without a pad token
        return tk, repo_id, False
    except Exception:
        tk = AutoTokenizer.from_pretrained(BASE_TOKENIZER, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token
        return tk, BASE_TOKENIZER, True

def _try_plain(repo_id: str):
    """Load the repo as a regular causal-LM checkpoint."""
    return AutoModelForCausalLM.from_pretrained(
        repo_id,
        token=HF_TOKEN,
        torch_dtype=DTYPE,  # `torch_dtype` is the portable from_pretrained kwarg
        device_map="auto" if torch.cuda.is_available() else None,
    )

def _try_peft(repo_id: str):
    """Load the repo as a PEFT/LoRA adapter checkpoint."""
    from peft import AutoPeftModelForCausalLM

    m = AutoPeftModelForCausalLM.from_pretrained(
        repo_id,
        token=HF_TOKEN,
        torch_dtype=DTYPE,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    # Merge the adapter into the base weights if possible; it's fine if this fails.
    try:
        m = m.merge_and_unload()
        merged = True
    except Exception:
        merged = False
    return m, merged

def load_model_any(repo_id: str):
    """Try plain, then PEFT; raise if both fail."""
    try:
        m = _try_plain(repo_id)
        return m, False
    except Exception as e_plain:
        try:
            m, merged = _try_peft(repo_id)
            return m, merged
        except Exception as e_peft:
            raise RuntimeError(f"load failed for {repo_id} | plain: {e_plain} | peft: {e_peft}")

# ---- Boot: try MODEL_ID first, then fallbacks ----------------------------
errors: Dict[str, str] = {}
chosen_id: Optional[str] = None
merged_lora = False

# Try the preferred model first, then the remaining fallbacks (deduplicated).
trial_ids = [PREFERRED_ID] + [i for i in FALLBACK_IDS if i != PREFERRED_ID]
for rid in trial_ids:
    try:
        tokenizer, tokenizer_source, tokenizer_fallback_used = _load_tokenizer(rid)
        model, merged_lora = load_model_any(rid)
        chosen_id = rid
        break
    except Exception as e:
        errors[rid] = str(e)

if chosen_id is None:
    raise RuntimeError(f"All model loads failed. Errors: {errors}")
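
# At this point `model`, `tokenizer`, and `chosen_id` refer to the first repo in
# `trial_ids` that loaded successfully; `errors` records any failed attempts and
# is surfaced by the "/" endpoint below.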

# A model loaded via device_map (accelerate) must not also be given a pipeline
# device, or transformers raises a ValueError; pass `device` only otherwise.
pipe_kwargs = {} if getattr(model, "hf_device_map", None) else {"device": device}
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **pipe_kwargs,
)
# ---- Prompting -----------------------------------------------------------
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"
def make_prompt(user_input: str) -> str:
    return f"{PREFIX}{user_input}{SUFFIX}"
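
# For example, make_prompt("Feeling a bit tired today") produces
#   "INPUT: Feeling a bit tired today\nOUTPUT:"
# which is presumably the INPUT/OUTPUT scheme the checkpoint was fine-tuned on.
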
class GenerateRequest(BaseModel):
    input: str = Field(..., min_length=1)
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1


class GenerateResponse(BaseModel):
    output: str
    prompt: str
    parameters: Dict[str, Any]

@app.get("/")
def root():
return {
"message": "Check-in GPT-2 API. POST /generate",
"model_chosen": chosen_id,
"device": "cuda" if device == 0 else "cpu",
"merged_lora": merged_lora,
"tokenizer_source": tokenizer_source,
"tokenizer_fallback_used": tokenizer_fallback_used,
"attempt_errors": errors,
"env_MODEL_ID": PREFERRED_ID,
}
@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        # Keep everything after the "OUTPUT:" marker; note that only the first
        # sequence is returned even when num_return_sequences > 1.
        text = gen[0]["generated_text"]
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
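
# Quick smoke test once the app is running (e.g. `uvicorn main:app --port 7860`;
# the port is just an example, though Spaces conventionally uses 7860):
#
#   curl -s -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"input": "Feeling a bit tired today", "max_new_tokens": 120}'
#
# The JSON response carries "output", "prompt", and the echoed "parameters".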