import os
from typing import Dict, Any, Optional, Tuple
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from huggingface_hub.utils import RepositoryNotFoundError
import torch
# ---- Config --------------------------------------------------------------
PREFERRED_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
FALLBACK_IDS = ["ethnmcl/checkin-lora-gpt2", "distilgpt2"]  # last resort that keeps the API alive
BASE_TOKENIZER = os.getenv("BASE_TOKENIZER", "gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")
app = FastAPI(title="Check-in GPT-2 API", version="1.3.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
)
device = 0 if torch.cuda.is_available() else -1
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# ---- Helpers -------------------------------------------------------------
def _load_tokenizer(repo_id: str) -> Tuple:
    """Try the repo's own tokenizer, then fall back to the base tokenizer."""
    try:
        tk = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token
        return tk, repo_id, False
    except Exception:
        tk = AutoTokenizer.from_pretrained(BASE_TOKENIZER, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token
        return tk, BASE_TOKENIZER, True

def _try_plain(repo_id: str):
    """Load repo_id as a plain causal-LM checkpoint."""
    return AutoModelForCausalLM.from_pretrained(
        repo_id, token=HF_TOKEN, dtype=DTYPE,
        device_map="auto" if torch.cuda.is_available() else None,
    )

def _try_peft(repo_id: str):
    """Load repo_id as a PEFT/LoRA adapter on top of its base model."""
    from peft import AutoPeftModelForCausalLM
    m = AutoPeftModelForCausalLM.from_pretrained(
        repo_id, token=HF_TOKEN, dtype=DTYPE,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    # Merge adapter weights if available; it's fine if not.
    try:
        m = m.merge_and_unload()
        merged = True
    except Exception:
        merged = False
    return m, merged

def load_model_any(repo_id: str):
    """Try plain, then PEFT; raise if both fail."""
    try:
        m = _try_plain(repo_id)
        return m, False
    except Exception as e_plain:
        try:
            m, merged = _try_peft(repo_id)
            return m, merged
        except Exception as e_peft:
            raise RuntimeError(f"load failed for {repo_id} | plain: {e_plain} | peft: {e_peft}")
# ---- Boot: try MODEL_ID first, then fallbacks ----------------------------
errors = {}
chosen_id: Optional[str] = None
merged_lora = False
trial_ids = [PREFERRED_ID] + [i for i in FALLBACK_IDS if i != PREFERRED_ID]
for rid in trial_ids:
    try:
        tokenizer, tokenizer_source, tokenizer_fallback_used = _load_tokenizer(rid)
        model, merged_lora = load_model_any(rid)
        chosen_id = rid
        break
    except Exception as e:
        errors[rid] = str(e)

if chosen_id is None:
    raise RuntimeError(f"All model loads failed. Errors: {errors}")

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
# ---- Prompting -----------------------------------------------------------
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"

def make_prompt(user_input: str) -> str:
    return f"{PREFIX}{user_input}{SUFFIX}"

class GenerateRequest(BaseModel):
    input: str = Field(..., min_length=1)
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1

class GenerateResponse(BaseModel):
    output: str
    prompt: str
    parameters: Dict[str, Any]
@app.get("/")
def root():
return {
"message": "Check-in GPT-2 API. POST /generate",
"model_chosen": chosen_id,
"device": "cuda" if device == 0 else "cpu",
"merged_lora": merged_lora,
"tokenizer_source": tokenizer_source,
"tokenizer_fallback_used": tokenizer_fallback_used,
"attempt_errors": errors,
"env_MODEL_ID": PREFERRED_ID,
}
@app.get("/health")
def health():
return {"status": "ok"}
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
try:
prompt = make_prompt(req.input)
gen = pipe(
prompt,
max_new_tokens=req.max_new_tokens,
temperature=req.temperature,
top_p=req.top_p,
top_k=req.top_k,
repetition_penalty=req.repetition_penalty,
do_sample=req.do_sample,
num_return_sequences=req.num_return_sequences,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
return_full_text=True,
)
text = gen[0]["generated_text"]
output = text.split("OUTPUT:", 1)[-1].strip()
return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
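
# --- Example usage (illustrative sketch, not part of the service) ---------
# Host, port, and module name below are assumptions, not defined in this file;
# adjust them to your deployment (e.g. a Hugging Face Space or a local uvicorn
# run). Assuming this file is saved as app.py:
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Daily check-in: slept badly, busy day ahead", "max_new_tokens": 120}'
#
# The response body follows GenerateResponse: {"output": ..., "prompt": ..., "parameters": ...}.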