import os
from typing import Dict, Any, Optional, Tuple

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from huggingface_hub.utils import RepositoryNotFoundError
import torch

# ---- Config --------------------------------------------------------------
PREFERRED_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
FALLBACK_IDS = ["ethnmcl/checkin-lora-gpt2", "distilgpt2"]  # last-resort keeps API alive
BASE_TOKENIZER = os.getenv("BASE_TOKENIZER", "gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")

app = FastAPI(title="Check-in GPT-2 API", version="1.3.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
)

device = 0 if torch.cuda.is_available() else -1
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# ---- Helpers -------------------------------------------------------------
def _load_tokenizer(repo_id: str) -> Tuple:
    """Try repo tokenizer, then fall back to base tokenizer."""
    try:
        tk = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token
        return tk, repo_id, False
    except Exception:
        tk = AutoTokenizer.from_pretrained(BASE_TOKENIZER, token=HF_TOKEN)
        if tk.pad_token is None:
            tk.pad_token = tk.eos_token
        return tk, BASE_TOKENIZER, True

def _try_plain(repo_id: str):
    return AutoModelForCausalLM.from_pretrained(
        repo_id, token=HF_TOKEN, dtype=DTYPE,
        device_map="auto" if torch.cuda.is_available() else None,
    )

def _try_peft(repo_id: str):
    from peft import AutoPeftModelForCausalLM
    m = AutoPeftModelForCausalLM.from_pretrained(
        repo_id, token=HF_TOKEN, dtype=DTYPE,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    # Merge adapter weights if available; ok if not
    try:
        m = m.merge_and_unload()
        merged = True
    except Exception:
        merged = False
    return m, merged

def load_model_any(repo_id: str):
    """Try plain, then PEFT; raise if both fail."""
    try:
        m = _try_plain(repo_id)
        return m, False
    except Exception as e_plain:
        try:
            m, merged = _try_peft(repo_id)
            return m, merged
        except Exception as e_peft:
            raise RuntimeError(f"load failed for {repo_id} | plain: {e_plain} | peft: {e_peft}")

# ---- Boot: try MODEL_ID first, then fallbacks ----------------------------
errors = {}
chosen_id: Optional[str] = None
merged_lora = False
trial_ids = [PREFERRED_ID] + [i for i in FALLBACK_IDS if i != PREFERRED_ID]

for rid in trial_ids:
    try:
        tokenizer, tokenizer_source, tokenizer_fallback_used = _load_tokenizer(rid)
        model, merged_lora = load_model_any(rid)
        chosen_id = rid
        break
    except Exception as e:
        errors[rid] = str(e)

if chosen_id is None:
    raise RuntimeError(f"All model loads failed. Errors: {errors}")

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

# ---- Prompting -----------------------------------------------------------
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"

def make_prompt(user_input: str) -> str:
    return f"{PREFIX}{user_input}{SUFFIX}"
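
# For illustration: make_prompt wraps the raw check-in text in the INPUT/OUTPUT
# template the model expects (the example text below is hypothetical):
#   make_prompt("Slept 5 hours, feeling foggy")
#   -> "INPUT: Slept 5 hours, feeling foggy\nOUTPUT:"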

class GenerateRequest(BaseModel):
    input: str = Field(..., min_length=1)
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1

class GenerateResponse(BaseModel):
    output: str
    prompt: str
    parameters: Dict[str, Any]

@app.get("/")
def root():
    return {
        "message": "Check-in GPT-2 API. POST /generate",
        "model_chosen": chosen_id,
        "device": "cuda" if device == 0 else "cpu",
        "merged_lora": merged_lora,
        "tokenizer_source": tokenizer_source,
        "tokenizer_fallback_used": tokenizer_fallback_used,
        "attempt_errors": errors,
        "env_MODEL_ID": PREFERRED_ID,
    }

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        text = gen[0]["generated_text"]
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
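
# Minimal local-run entrypoint, sketched as an assumption (host and port are not
# from the source; 7860 is the usual Hugging Face Spaces default):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against the running server (hypothetical input text):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Slept badly and skipped the gym", "max_new_tokens": 120}'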