ethnmcl committed · verified
Commit 1bb585d · 1 Parent(s): adeaf8c

Update main.py

Files changed (1)
  1. main.py +54 -11
main.py CHANGED
@@ -6,34 +6,72 @@ from pydantic import BaseModel, Field
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
 
-MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-gpt2")
-HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Secrets if the model repo is private
-
-app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")
-
+# === Config ===
+MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
+HF_TOKEN = os.getenv("HF_TOKEN")  # if the repo is private, set this in Secrets
+
+app = FastAPI(title="Check-in GPT-2 API", version="1.1.0")
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
 )
 
+# Choose device: GPU index 0 if available else CPU
 device = 0 if torch.cuda.is_available() else -1
 
-# use token= (not use_auth_token) and rely on HF_HOME=/data/huggingface
+# === Load tokenizer ===
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
-model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)
 
+# === Load model (supports plain CausalLM repos AND PEFT LoRA adapters) ===
+# Strategy:
+#   1) Try plain AutoModelForCausalLM
+#   2) If that fails (likely LoRA-only repo), try PEFT AutoPeftModelForCausalLM and merge
+_model = None
+_merged = False
+try:
+    _model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        token=HF_TOKEN,
+        # Use 'dtype' not deprecated 'torch_dtype'
+        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto" if torch.cuda.is_available() else None,
+    )
+except Exception as e_plain:
+    # Fall back to PEFT path
+    try:
+        from peft import AutoPeftModelForCausalLM
+        _model = AutoPeftModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            token=HF_TOKEN,
+            dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto" if torch.cuda.is_available() else None,
+        )
+        # Merge LoRA into base weights so inference behaves like a standard CausalLM
+        try:
+            _model = _model.merge_and_unload()
+            _merged = True
+        except Exception:
+            # If merge not available, we can still run with adapters active
+            _merged = False
+    except Exception as e_peft:
+        raise RuntimeError(
+            f"Failed to load model '{MODEL_ID}'. "
+            f"Plain load error: {e_plain}\nPEFT load error: {e_peft}"
+        )
+
+# Build pipeline
 pipe = pipeline(
     "text-generation",
-    model=model,
+    model=_model,
     tokenizer=tokenizer,
-    device=device
+    device=device,
 )
 
+# Prompt shape (keep if you rely on INPUT/OUTPUT markers; otherwise switch to 'Check-in: ')
 PREFIX = "INPUT: "
 SUFFIX = "\nOUTPUT:"
-
 def make_prompt(user_input: str) -> str:
     return f"{PREFIX}{user_input}{SUFFIX}"
 
@@ -54,7 +92,12 @@ class GenerateResponse(BaseModel):
 
 @app.get("/")
 def root():
-    return {"message": "Check-in GPT-2 API. POST /generate", "model": MODEL_ID, "device": "cuda" if device == 0 else "cpu"}
+    return {
+        "message": "Check-in GPT-2 API. POST /generate",
+        "model": MODEL_ID,
+        "device": "cuda" if device == 0 else "cpu",
+        "merged_lora": _merged,
+    }
 
 @app.get("/health")
 def health():
@@ -75,7 +118,7 @@ def generate(req: GenerateRequest):
         num_return_sequences=req.num_return_sequences,
         pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
-        return_full_text=True
+        return_full_text=True,
     )
     text = gen[0]["generated_text"]
    output = text.split("OUTPUT:", 1)[-1].strip()
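For reference, a minimal client sketch for exercising the updated endpoints. This is illustrative only: the Space's base URL and the name of the text field on GenerateRequest are not visible in this diff, so the "prompt" key and BASE_URL below are assumptions; only num_return_sequences appears in the hunks above.

# Hypothetical client for the Space API (field names other than
# num_return_sequences are assumptions; adjust to the real GenerateRequest).
import requests

BASE_URL = "http://localhost:7860"  # assumed local Space URL

# Root endpoint reports the model id, device, and whether LoRA weights were merged
print(requests.get(f"{BASE_URL}/", timeout=10).json())

# POST /generate: the server wraps the input as "INPUT: ...\nOUTPUT:" and
# returns the text after the "OUTPUT:" marker.
payload = {
    "prompt": "Felt low-energy this morning but finished my tasks.",  # assumed field name
    "num_return_sequences": 1,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())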