# Hugging Face Space entrypoint — Check-in GPT-2 API
import os
from typing import Any, Dict

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# === Configuration ===
# Model repo to serve; override via the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
# Only needed when the model repo is private (set in Space Secrets).
HF_TOKEN = os.getenv("HF_TOKEN")

app = FastAPI(title="Check-in GPT-2 API", version="1.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Pipeline device: CUDA device index 0 when a GPU is present, otherwise -1 (CPU).
device = 0 if torch.cuda.is_available() else -1
# === Tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# GPT-2 checkpoints ship without a pad token; reuse EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# === Model loading (supports plain CausalLM repos AND PEFT LoRA adapters) ===
# Strategy:
#   1) Try a plain AutoModelForCausalLM load.
#   2) If that fails (likely a LoRA-adapter-only repo), fall back to
#      PEFT's AutoPeftModelForCausalLM and merge the adapter weights.
_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
_device_map = "auto" if torch.cuda.is_available() else None

_model = None
_merged = False  # True once LoRA weights have been merged into the base model
try:
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        # Use 'dtype' rather than the deprecated 'torch_dtype' kwarg.
        dtype=_dtype,
        device_map=_device_map,
    )
except Exception as e_plain:
    # Fall back to the PEFT path for adapter-only repos.
    try:
        from peft import AutoPeftModelForCausalLM

        _model = AutoPeftModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            dtype=_dtype,
            device_map=_device_map,
        )
        # Merge LoRA into the base weights so inference behaves like a
        # standard CausalLM; if merging is unavailable, run with adapters active.
        try:
            _model = _model.merge_and_unload()
            _merged = True
        except Exception:
            _merged = False
    except Exception as e_peft:
        # Chain the PEFT failure so the full cause appears in the traceback.
        raise RuntimeError(
            f"Failed to load model '{MODEL_ID}'. "
            f"Plain load error: {e_plain}\nPEFT load error: {e_peft}"
        ) from e_peft
# Build the text-generation pipeline.
# When the model was loaded with device_map="auto", accelerate has already
# placed it on device(s); passing `device` as well makes transformers raise
# a ValueError. Only pass `device` for models without an hf_device_map.
_pipe_kwargs: Dict[str, Any] = {}
if getattr(_model, "hf_device_map", None) is None:
    _pipe_kwargs["device"] = device
pipe = pipeline(
    "text-generation",
    model=_model,
    tokenizer=tokenizer,
    **_pipe_kwargs,
)
# Prompt template. The fine-tuned model expects the INPUT/OUTPUT markers;
# switch to a plain prefix (e.g. "Check-in: ") only if you retrain without them.
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"


def make_prompt(user_input: str) -> str:
    """Wrap raw user text in the INPUT/OUTPUT markers the model was trained on."""
    return "".join((PREFIX, user_input, SUFFIX))
class GenerateRequest(BaseModel):
    """Body of POST /generate: the check-in text plus sampling parameters."""

    # Raw check-in text; must be non-empty.
    input: str = Field(..., min_length=1)
    # Sampling / decoding knobs, with conservative defaults.
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1
class GenerateResponse(BaseModel):
    """Response of POST /generate: generated text, the prompt used, and the
    request parameters echoed back for reproducibility."""

    output: str
    prompt: str
    parameters: Dict[str, Any]
@app.get("/")
def root():
    """Service metadata endpoint; also confirms the API is up.

    Returns the served model id, the device in use, and whether the LoRA
    adapter was merged into the base weights at load time.
    """
    return {
        "message": "Check-in GPT-2 API. POST /generate",
        "model": MODEL_ID,
        "device": "cuda" if device == 0 else "cpu",
        "merged_lora": _merged,
    }
@app.get("/health")
def health():
    """Liveness probe for the hosting platform."""
    return {"status": "ok"}
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a model reply for one check-in.

    Builds the INPUT/OUTPUT prompt, runs the text-generation pipeline with
    the caller-supplied sampling parameters, and returns everything after
    the first 'OUTPUT:' marker. Any failure surfaces as HTTP 500 with the
    error message in the detail field.
    """
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        # Full text includes the echoed prompt; keep only what follows the
        # first "OUTPUT:" marker.
        text = gen[0]["generated_text"]
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e