import os
from typing import Dict, Any, Optional, Tuple

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# ---- Config --------------------------------------------------------------

PREFERRED_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
FALLBACK_IDS = ["ethnmcl/checkin-lora-gpt2", "distilgpt2"]  # last-resort keeps API alive
BASE_TOKENIZER = os.getenv("BASE_TOKENIZER", "gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")

app = FastAPI(title="Check-in GPT-2 API", version="1.3.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

device = 0 if torch.cuda.is_available() else -1
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# ---- Helpers -------------------------------------------------------------

def _load_tokenizer(repo_id: str) -> Tuple[Any, str, bool]:
    """Try the repo tokenizer, then fall back to the base tokenizer."""
    try:
        tk = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token
        return tk, repo_id, False
    except Exception:
        tk = AutoTokenizer.from_pretrained(BASE_TOKENIZER, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token
        return tk, BASE_TOKENIZER, True

def _try_plain(repo_id: str):
    return AutoModelForCausalLM.from_pretrained(
        repo_id,
        token=HF_TOKEN,
        torch_dtype=DTYPE,
        device_map="auto" if torch.cuda.is_available() else None,
    )

def _try_peft(repo_id: str):
    from peft import AutoPeftModelForCausalLM

    m = AutoPeftModelForCausalLM.from_pretrained(
        repo_id,
        token=HF_TOKEN,
        torch_dtype=DTYPE,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    # Merge adapters if available; it's fine if this fails.
    try:
        m = m.merge_and_unload()
        merged = True
    except Exception:
        merged = False
    return m, merged

def load_model_any(repo_id: str):
    """Try plain, then PEFT; raise if both fail."""
    try:
        m = _try_plain(repo_id)
        return m, False
    except Exception as e_plain:
        try:
            m, merged = _try_peft(repo_id)
            return m, merged
        except Exception as e_peft:
            raise RuntimeError(
                f"load failed for {repo_id} | plain: {e_plain} | peft: {e_peft}"
            )

# ---- Boot: try MODEL_ID first, then fallbacks ----------------------------

errors: Dict[str, str] = {}
chosen_id: Optional[str] = None
merged_lora = False

trial_ids = [PREFERRED_ID] + [i for i in FALLBACK_IDS if i != PREFERRED_ID]
for rid in trial_ids:
    try:
        tokenizer, tokenizer_source, tokenizer_fallback_used = _load_tokenizer(rid)
        model, merged_lora = load_model_any(rid)
        chosen_id = rid
        break
    except Exception as e:
        errors[rid] = str(e)

if chosen_id is None:
    raise RuntimeError(f"All model loads failed. Errors: {errors}")

# On the CUDA path the model was loaded with device_map="auto", so accelerate
# has already placed it; passing device= to pipeline() then raises. Only pin
# the device explicitly on CPU.
pipe_kwargs: Dict[str, Any] = {}
if not torch.cuda.is_available():
    pipe_kwargs["device"] = device

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **pipe_kwargs,
)

# ---- Prompting -----------------------------------------------------------

PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"

def make_prompt(user_input: str) -> str:
    return f"{PREFIX}{user_input}{SUFFIX}"

class GenerateRequest(BaseModel):
    input: str = Field(..., min_length=1)
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1

class GenerateResponse(BaseModel):
    output: str
    prompt: str
    parameters: Dict[str, Any]
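# Worked example of the prompt template (the sample input is illustrative):
#   make_prompt("Feeling drained after standup")
#   -> "INPUT: Feeling drained after standup\nOUTPUT:"
# generate() below splits the decoded text on "OUTPUT:" and returns only the
# completion, so the fine-tuned model is assumed to follow this template.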
@app.get("/")
def root():
    return {
        "message": "Check-in GPT-2 API. POST /generate",
        "model_chosen": chosen_id,
        "device": "cuda" if device == 0 else "cpu",
        "merged_lora": merged_lora,
        "tokenizer_source": tokenizer_source,
        "tokenizer_fallback_used": tokenizer_fallback_used,
        "attempt_errors": errors,
        "env_MODEL_ID": PREFERRED_ID,
    }

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        text = gen[0]["generated_text"]
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
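# Local run sketch (module name "main" is an assumption; adjust to this file):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example request against the /generate endpoint defined above:
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"input": "Feeling a bit tired today", "max_new_tokens": 120}'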