import os
from typing import Dict, Any

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# === Config ===
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")  # if the repo is private, set this in Secrets

app = FastAPI(title="Check-in GPT-2 API", version="1.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Choose device: GPU index 0 if available, else CPU
device = 0 if torch.cuda.is_available() else -1

# === Load tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# === Load model (supports plain CausalLM repos AND PEFT LoRA adapters) ===
# Strategy:
#   1) Try plain AutoModelForCausalLM
#   2) If that fails (likely a LoRA-only repo), try PEFT AutoPeftModelForCausalLM and merge
_model = None
_merged = False
try:
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        # Use 'dtype', not the deprecated 'torch_dtype'
        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
except Exception as e_plain:
    # Fall back to the PEFT path
    try:
        from peft import AutoPeftModelForCausalLM

        _model = AutoPeftModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        # Merge LoRA into the base weights so inference behaves like a standard CausalLM
        try:
            _model = _model.merge_and_unload()
            _merged = True
        except Exception:
            # If merging is unavailable, we can still run with the adapters active
            _merged = False
    except Exception as e_peft:
        raise RuntimeError(
            f"Failed to load model '{MODEL_ID}'. "
            f"Plain load error: {e_plain}\nPEFT load error: {e_peft}"
        )

# Build the generation pipeline.
# When device_map="auto" was used (accelerate), the pipeline must not also be given an
# explicit device, or it refuses the already-dispatched model; only pin one on the CPU path.
pipe = pipeline(
    "text-generation",
    model=_model,
    tokenizer=tokenizer,
    device=None if torch.cuda.is_available() else device,
)

# Prompt shape (keep if you rely on INPUT/OUTPUT markers; otherwise switch to 'Check-in: ')
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"


def make_prompt(user_input: str) -> str:
    return f"{PREFIX}{user_input}{SUFFIX}"


class GenerateRequest(BaseModel):
    input: str = Field(..., min_length=1)
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1


class GenerateResponse(BaseModel):
    output: str
    prompt: str
    parameters: Dict[str, Any]


@app.get("/")
def root():
    return {
        "message": "Check-in GPT-2 API. POST /generate",
        "model": MODEL_ID,
        "device": "cuda" if device == 0 else "cpu",
        "merged_lora": _merged,
    }


@app.get("/health")
def health():
    return {"status": "ok"}


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        text = gen[0]["generated_text"]
        # Keep only what the model produced after the OUTPUT: marker
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))