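"""FastAPI service exposing a check-in GPT-2 model fine-tuned with LoRA.

On boot it tries MODEL_ID first, then each entry in FALLBACK_IDS, loading
either a plain causal-LM checkpoint or a PEFT adapter (merged when possible)
so the API stays alive even if the preferred repo fails to load.
"""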
import os
from typing import Dict, Any, Optional, Tuple
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# ---- Config --------------------------------------------------------------
PREFERRED_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
FALLBACK_IDS = ["ethnmcl/checkin-lora-gpt2", "distilgpt2"]  # last-resort fallback keeps the API alive
BASE_TOKENIZER = os.getenv("BASE_TOKENIZER", "gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")

app = FastAPI(title="Check-in GPT-2 API", version="1.3.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
)
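# NOTE: wildcard origins with credentials enabled is convenient for demos;
# pin allow_origins to known frontends before production use.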

device = 0 if torch.cuda.is_available() else -1
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# ---- Helpers -------------------------------------------------------------
def _load_tokenizer(repo_id: str) -> Tuple[Any, str, bool]:
    """Try the repo's own tokenizer, then fall back to the base tokenizer.

    Returns (tokenizer, source_id, fallback_used).
    """
    try:
        tk = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token
        return tk, repo_id, False
    except Exception:
        tk = AutoTokenizer.from_pretrained(BASE_TOKENIZER, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token
        return tk, BASE_TOKENIZER, True

def _try_plain(repo_id: str):
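    # NOTE: the `dtype=` kwarg assumes a recent transformers release;
    # older versions expect `torch_dtype=` instead.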
    return AutoModelForCausalLM.from_pretrained(
        repo_id, token=HF_TOKEN, dtype=DTYPE,
        device_map="auto" if torch.cuda.is_available() else None,
    )

def _try_peft(repo_id: str):
    from peft import AutoPeftModelForCausalLM
    m = AutoPeftModelForCausalLM.from_pretrained(
        repo_id, token=HF_TOKEN, dtype=DTYPE,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    # Merge LoRA weights into the base model when supported; keep the
    # unmerged adapter otherwise.
    try:
        m = m.merge_and_unload()
        merged = True
    except Exception:
        merged = False
    return m, merged

def load_model_any(repo_id: str):
    """Try plain, then PEFT; raise if both fail."""
    try:
        m = _try_plain(repo_id)
        return m, False
    except Exception as e_plain:
        try:
            m, merged = _try_peft(repo_id)
            return m, merged
        except Exception as e_peft:
            raise RuntimeError(f"load failed for {repo_id} | plain: {e_plain} | peft: {e_peft}")

# ---- Boot: try MODEL_ID first, then fallbacks ----------------------------
errors: Dict[str, str] = {}
chosen_id: Optional[str] = None
merged_lora = False

trial_ids = [PREFERRED_ID] + [i for i in FALLBACK_IDS if i != PREFERRED_ID]
for rid in trial_ids:
    try:
        tokenizer, tokenizer_source, tokenizer_fallback_used = _load_tokenizer(rid)
        model, merged_lora = load_model_any(rid)
        chosen_id = rid
        break
    except Exception as e:
        errors[rid] = str(e)

if chosen_id is None:
    raise RuntimeError(f"All model loads failed. Errors: {errors}")

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # A device_map="auto" model is already placed by accelerate; passing a
    # device would conflict, so only pin the device on the CPU path.
    device=None if torch.cuda.is_available() else device,
)

# ---- Prompting -----------------------------------------------------------
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"
def make_prompt(user_input: str) -> str:
    return f"{PREFIX}{user_input}{SUFFIX}"
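# e.g. make_prompt("Slept badly, feeling anxious")
#   -> "INPUT: Slept badly, feeling anxious\nOUTPUT:"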

class GenerateRequest(BaseModel):
    input: str = Field(..., min_length=1)
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1

class GenerateResponse(BaseModel):
    output: str
    prompt: str
    parameters: Dict[str, Any]

@app.get("/")
def root():
    return {
        "message": "Check-in GPT-2 API. POST /generate",
        "model_chosen": chosen_id,
        "device": "cuda" if device == 0 else "cpu",
        "merged_lora": merged_lora,
        "tokenizer_source": tokenizer_source,
        "tokenizer_fallback_used": tokenizer_fallback_used,
        "attempt_errors": errors,
        "env_MODEL_ID": PREFERRED_ID,
    }

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
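        # Only the first sequence is used even if num_return_sequences > 1;
        # return_full_text=True echoes the prompt, so keep only the text
        # after the "OUTPUT:" marker.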
        text = gen[0]["generated_text"]
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
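
# ---- Local entry point -----------------------------------------------------
# A minimal sketch for running the server directly, assuming this file is
# saved as app.py and uvicorn is installed (`uvicorn app:app` works too).
# Example request once the server is up:
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Feeling low energy today", "max_new_tokens": 120}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))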