File size: 4,193 Bytes
c7d9782
 
 
 
 
 
 
 
1bb585d
 
 
c7d9782
1bb585d
c7d9782
 
 
 
 
1bb585d
c7d9782
adeaf8c
1bb585d
adeaf8c
c7d9782
 
 
1bb585d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7d9782
 
1bb585d
c7d9782
1bb585d
c7d9782
 
1bb585d
c7d9782
 
 
 
 
 
adeaf8c
c7d9782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bb585d
 
 
 
 
 
c7d9782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bb585d
c7d9782
 
 
 
 
 
adeaf8c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
from typing import Dict, Any
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# === Config ===
# Hugging Face Hub repo id of the model to serve; override via the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")  # if the repo is private, set this in Secrets

# FastAPI app with fully open CORS (all origins/methods/headers) —
# convenient for demos; tighten allow_origins for production.
app = FastAPI(title="Check-in GPT-2 API", version="1.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
)

# Choose device: GPU index 0 if available else CPU
# (-1 is the transformers pipeline convention for CPU).
device = 0 if torch.cuda.is_available() else -1

# === Load tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# GPT-2-family tokenizers ship without a pad token; reuse EOS so padding
# during generation has a valid id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# === Load model (supports plain CausalLM repos AND PEFT LoRA adapters) ===
# Strategy:
#   1) Try plain AutoModelForCausalLM
#   2) If that fails (likely LoRA-only repo), try PEFT AutoPeftModelForCausalLM and merge
_model = None
_merged = False  # exposed via GET / so clients can see whether LoRA was folded in
try:
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        # Use 'dtype' not deprecated 'torch_dtype'
        # fp16 on GPU to halve memory; fp32 on CPU where fp16 is slow/unsupported.
        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
except Exception as e_plain:
    # Fall back to PEFT path
    try:
        # Imported lazily so the server still starts without peft installed
        # when the repo is a plain CausalLM.
        from peft import AutoPeftModelForCausalLM
        _model = AutoPeftModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        # Merge LoRA into base weights so inference behaves like a standard CausalLM
        try:
            _model = _model.merge_and_unload()
            _merged = True
        except Exception:
            # If merge not available, we still can run with adapters active
            _merged = False
    except Exception as e_peft:
        # Surface BOTH load errors — either one alone is misleading when
        # diagnosing whether the repo is a full model or an adapter-only repo.
        raise RuntimeError(
            f"Failed to load model '{MODEL_ID}'. "
            f"Plain load error: {e_plain}\nPEFT load error: {e_peft}"
        )

# Build the text-generation pipeline around the loaded model/tokenizer.
# NOTE(review): when the model was loaded with device_map="auto" (accelerate),
# passing an explicit `device` to pipeline() conflicts with the accelerate
# placement and can raise in recent transformers versions — so only forward
# `device` when the model carries no accelerate device map. The CPU path
# (device=-1, no device_map) is unchanged.
pipe = pipeline(
    "text-generation",
    model=_model,
    tokenizer=tokenizer,
    device=None if getattr(_model, "hf_device_map", None) is not None else device,
)

# Prompt template markers. The model was fine-tuned on "INPUT: ... \nOUTPUT:"
# pairs — keep these if you rely on that format; otherwise switch to 'Check-in: '.
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"


def make_prompt(user_input: str) -> str:
    """Wrap *user_input* between the INPUT/OUTPUT markers the model expects."""
    return "".join((PREFIX, user_input, SUFFIX))

class GenerateRequest(BaseModel):
    """Request body for POST /generate; knobs mirror HF generate() kwargs."""
    input: str = Field(..., min_length=1)  # raw user text; wrapped by make_prompt()
    max_new_tokens: int = 180              # length budget for the completion only
    temperature: float = 0.7               # sampling temperature (only used when do_sample)
    top_p: float = 0.95                    # nucleus sampling cutoff
    top_k: int = 50                        # top-k sampling cutoff
    repetition_penalty: float = 1.05       # >1.0 mildly discourages repeats
    do_sample: bool = True                 # False => greedy decoding
    num_return_sequences: int = 1          # NOTE: /generate returns only the first sequence

class GenerateResponse(BaseModel):
    """Response body for POST /generate."""
    output: str                # generated text after the OUTPUT: marker
    prompt: str                # full prompt that was fed to the model
    parameters: Dict[str, Any] # echo of the request parameters used

@app.get("/")
def root():
    """Service banner: model id, active device, and LoRA-merge status."""
    hw = "cuda" if device == 0 else "cpu"
    return {
        "message": "Check-in GPT-2 API. POST /generate",
        "model": MODEL_ID,
        "device": hw,
        "merged_lora": _merged,
    }

@app.get("/health")
def health():
    """Liveness probe for the hosting platform."""
    payload = {"status": "ok"}
    return payload

@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a completion for req.input.

    Wraps the input in the INPUT/OUTPUT prompt template, runs the HF
    text-generation pipeline with the requested sampling parameters, and
    returns the text produced after the prompt. Any failure surfaces as a
    500 with the underlying error message.
    """
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        # Only the first sequence is returned, even when num_return_sequences > 1
        # (unchanged contract from the original implementation).
        text = gen[0]["generated_text"]
        # Fix: splitting on the first "OUTPUT:" truncated wrongly whenever the
        # *user input* itself contained "OUTPUT:". With return_full_text=True the
        # pipeline echoes the prompt verbatim, so strip it by length instead;
        # keep the old split as a fallback in case the echo ever differs.
        if text.startswith(prompt):
            output = text[len(prompt):].strip()
        else:
            output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e