import os
from typing import Dict, Any
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
# === Config ===
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")  # required only when the model repo is private

# transformers pipelines take a GPU index (0) or -1 for CPU.
device = 0 if torch.cuda.is_available() else -1

app = FastAPI(title="Check-in GPT-2 API", version="1.1.0")

# Wide-open CORS so any frontend origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Load tokenizer ===
# GPT-2 tokenizers ship without a pad token; generation code needs one for
# padding, so fall back to EOS (the standard GPT-2 convention). Note the
# assignment goes through the tokenizer's pad_token setter, which also
# updates pad_token_id.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# === Load model (supports plain CausalLM repos AND PEFT LoRA adapters) ===
# Strategy:
#   1) Try plain AutoModelForCausalLM (works for full-weight repos).
#   2) If that fails (likely an adapter-only LoRA repo), load via PEFT's
#      AutoPeftModelForCausalLM and merge the adapter into the base weights.
# On CUDA the model is loaded fp16 and placed by accelerate (device_map="auto");
# on CPU it stays fp32 on the default device.
_model = None
_merged = False  # True once LoRA weights have been folded into the base model
try:
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        # Use 'dtype' not deprecated 'torch_dtype'
        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
except Exception as e_plain:
    # Fall back to PEFT path
    try:
        from peft import AutoPeftModelForCausalLM
        _model = AutoPeftModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        # Merge LoRA into base weights so inference behaves like a standard CausalLM
        try:
            _model = _model.merge_and_unload()
            _merged = True
        except Exception:
            # If merge not available, we still can run with adapters active
            _merged = False
    except Exception as e_peft:
        # Neither load path worked; report both errors so the failure is debuggable.
        raise RuntimeError(
            f"Failed to load model '{MODEL_ID}'. "
            f"Plain load error: {e_plain}\nPEFT load error: {e_peft}"
        )
# Build the text-generation pipeline around the (possibly merged) model.
# BUG FIX: on the CUDA path the model was loaded with device_map="auto", i.e.
# accelerate has already placed it; also passing device= to pipeline() then
# fails in transformers ("model has been loaded with accelerate and therefore
# cannot be moved to a specific device"). Only pass an explicit device when
# the model is on CPU.
if torch.cuda.is_available():
    pipe = pipeline(
        "text-generation",
        model=_model,
        tokenizer=tokenizer,
    )
else:
    pipe = pipeline(
        "text-generation",
        model=_model,
        tokenizer=tokenizer,
        device=device,
    )
# Prompt markers — the model was trained on "INPUT: ...\nOUTPUT:" pairs.
# Keep these if you rely on the markers; otherwise switch to 'Check-in: '.
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"


def make_prompt(user_input: str) -> str:
    """Wrap *user_input* in the INPUT/OUTPUT markers the model expects."""
    return "".join((PREFIX, user_input, SUFFIX))
class GenerateRequest(BaseModel):
    """Request payload for POST /generate.

    Defaults are unchanged from the original API. The bounds below reject
    nonsensical sampling parameters with a 422 validation error instead of
    letting them surface as a 500 from inside the generation pipeline.
    """
    input: str = Field(..., min_length=1)            # user check-in text
    max_new_tokens: int = Field(180, ge=1, le=2048)  # completion length cap
    temperature: float = Field(0.7, gt=0.0, le=5.0)  # softmax temperature (sampling only)
    top_p: float = Field(0.95, gt=0.0, le=1.0)       # nucleus sampling mass
    top_k: int = Field(50, ge=0)                     # 0 disables top-k filtering
    repetition_penalty: float = Field(1.05, gt=0.0)
    do_sample: bool = True
    num_return_sequences: int = Field(1, ge=1, le=10)
class GenerateResponse(BaseModel):
    """Response payload for POST /generate."""
    output: str  # generated text after the "OUTPUT:" marker, stripped
    prompt: str  # full prompt that was sent to the model
    parameters: Dict[str, Any]  # echo of the request's generation parameters
@app.get("/")
def root():
    """Service metadata: usage hint, model id, device in use, LoRA merge status."""
    info = {
        "message": "Check-in GPT-2 API. POST /generate",
        "model": MODEL_ID,
    }
    info["device"] = "cuda" if device == 0 else "cpu"
    info["merged_lora"] = _merged
    return info
@app.get("/health")
def health():
    """Liveness probe: reports OK whenever the process is up."""
    return dict(status="ok")
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a completion for ``req.input``.

    Builds the INPUT/OUTPUT prompt, runs the text-generation pipeline with the
    request's sampling parameters, and returns everything after the final
    "OUTPUT:" marker. Only the first sequence is returned even when
    num_return_sequences > 1 (matching the original behavior).

    Raises HTTPException(500) if generation fails.
    """
    prompt = make_prompt(req.input)
    # Keep the try body to just the call that can realistically fail, so the
    # 500 detail always describes a generation error.
    try:
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
    except Exception as e:
        # Surface generation failures as a 500 with the underlying cause chained.
        raise HTTPException(status_code=500, detail=str(e)) from e
    # return_full_text=True means the prompt (ending in "OUTPUT:") is included;
    # keep only the text after that marker.
    text = gen[0]["generated_text"]
    output = text.split("OUTPUT:", 1)[-1].strip()
    return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())