ethnmcl committed on
Commit
2f012f6
·
verified ·
1 Parent(s): 4fae63b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -27
main.py CHANGED
@@ -6,54 +6,59 @@ from pydantic import BaseModel, Field
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  import torch
8
 
9
- # === Config ===
10
- MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
11
- HF_TOKEN = os.getenv("HF_TOKEN") # if the repo is private, set this in Secrets
 
12
 
13
- app = FastAPI(title="Check-in GPT-2 API", version="1.1.0")
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
17
  )
18
 
19
- # Choose device: GPU index 0 if available else CPU
20
  device = 0 if torch.cuda.is_available() else -1
 
21
 
22
- # === Load tokenizer ===
23
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
24
- if tokenizer.pad_token is None:
25
- tokenizer.pad_token = tokenizer.eos_token
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # === Load model (supports plain CausalLM repos AND PEFT LoRA adapters) ===
28
- # Strategy:
29
- # 1) Try plain AutoModelForCausalLM
30
- # 2) If that fails (likely LoRA-only repo), try PEFT AutoPeftModelForCausalLM and merge
31
- _model = None
32
  _merged = False
33
  try:
34
- _model = AutoModelForCausalLM.from_pretrained(
35
  MODEL_ID,
36
  token=HF_TOKEN,
37
- # Use 'dtype' not deprecated 'torch_dtype'
38
- dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
39
  device_map="auto" if torch.cuda.is_available() else None,
40
  )
41
  except Exception as e_plain:
42
- # Fall back to PEFT path
43
  try:
44
  from peft import AutoPeftModelForCausalLM
45
- _model = AutoPeftModelForCausalLM.from_pretrained(
46
  MODEL_ID,
47
  token=HF_TOKEN,
48
- dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
49
  device_map="auto" if torch.cuda.is_available() else None,
50
  )
51
- # Merge LoRA into base weights so inference behaves like a standard CausalLM
52
  try:
53
- _model = _model.merge_and_unload()
54
  _merged = True
55
  except Exception:
56
- # If merge not available, we still can run with adapters active
57
  _merged = False
58
  except Exception as e_peft:
59
  raise RuntimeError(
@@ -61,15 +66,14 @@ except Exception as e_plain:
61
  f"Plain load error: {e_plain}\nPEFT load error: {e_peft}"
62
  )
63
 
64
- # Build pipeline
65
  pipe = pipeline(
66
  "text-generation",
67
- model=_model,
68
  tokenizer=tokenizer,
69
  device=device,
70
  )
71
 
72
- # Prompt shape (keep if you rely on INPUT/OUTPUT markers; otherwise switch to 'Check-in: ')
73
  PREFIX = "INPUT: "
74
  SUFFIX = "\nOUTPUT:"
75
  def make_prompt(user_input: str) -> str:
@@ -97,6 +101,8 @@ def root():
97
  "model": MODEL_ID,
98
  "device": "cuda" if device == 0 else "cpu",
99
  "merged_lora": _merged,
 
 
100
  }
101
 
102
  @app.get("/health")
@@ -125,4 +131,3 @@ def generate(req: GenerateRequest):
125
  return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
126
  except Exception as e:
127
  raise HTTPException(status_code=500, detail=str(e))
128
-
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  import torch
8
 
9
+ # ---- Config -----------------------------------------------------------------
10
+ MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2") # NEW default
11
+ BASE_TOKENIZER = os.getenv("BASE_TOKENIZER", "gpt2") # fallback if LoRA repo has no tokenizer
12
+ HF_TOKEN = os.getenv("HF_TOKEN") # set if private
13
 
14
+ app = FastAPI(title="Check-in GPT-2 API", version="1.2.0")
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
18
  )
19
 
 
20
  device = 0 if torch.cuda.is_available() else -1
21
+ DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
22
 
23
+ # ---- Tokenizer (with fallback for adapter-only repos) ------------------------
24
+ def load_tokenizer(repo_id: str, token: str | None):
25
+ try:
26
+ tk = AutoTokenizer.from_pretrained(repo_id, token=token)
27
+ if tk.pad_token is None:
28
+ tk.pad_token = tk.eos_token
29
+ return tk, repo_id, False
30
+ except Exception as e_model_tok:
31
+ # Adapter repos often don't include tokenizer files: fallback to base tokenizer
32
+ tk = AutoTokenizer.from_pretrained(BASE_TOKENIZER, token=token)
33
+ if tk.pad_token is None:
34
+ tk.pad_token = tk.eos_token
35
+ return tk, BASE_TOKENIZER, True
36
+
37
+ tokenizer, tokenizer_source, tokenizer_fallback = load_tokenizer(MODEL_ID, HF_TOKEN)
38
 
39
+ # ---- Model (plain or PEFT LoRA) ---------------------------------------------
 
 
 
 
40
  _merged = False
41
  try:
42
+ model = AutoModelForCausalLM.from_pretrained(
43
  MODEL_ID,
44
  token=HF_TOKEN,
45
+ dtype=DTYPE,
 
46
  device_map="auto" if torch.cuda.is_available() else None,
47
  )
48
  except Exception as e_plain:
49
+ # Try PEFT (adapter) path
50
  try:
51
  from peft import AutoPeftModelForCausalLM
52
+ model = AutoPeftModelForCausalLM.from_pretrained(
53
  MODEL_ID,
54
  token=HF_TOKEN,
55
+ dtype=DTYPE,
56
  device_map="auto" if torch.cuda.is_available() else None,
57
  )
 
58
  try:
59
+ model = model.merge_and_unload()
60
  _merged = True
61
  except Exception:
 
62
  _merged = False
63
  except Exception as e_peft:
64
  raise RuntimeError(
 
66
  f"Plain load error: {e_plain}\nPEFT load error: {e_peft}"
67
  )
68
 
 
69
  pipe = pipeline(
70
  "text-generation",
71
+ model=model,
72
  tokenizer=tokenizer,
73
  device=device,
74
  )
75
 
76
+ # ---- Prompting ---------------------------------------------------------------
77
  PREFIX = "INPUT: "
78
  SUFFIX = "\nOUTPUT:"
79
  def make_prompt(user_input: str) -> str:
 
101
  "model": MODEL_ID,
102
  "device": "cuda" if device == 0 else "cpu",
103
  "merged_lora": _merged,
104
+ "tokenizer_source": tokenizer_source,
105
+ "tokenizer_fallback_used": tokenizer_fallback,
106
  }
107
 
108
  @app.get("/health")
 
131
  return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
132
  except Exception as e:
133
  raise HTTPException(status_code=500, detail=str(e))