# main.py (Corrected)
import logging
from contextlib import asynccontextmanager
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, GPT2Config
from huggingface_hub import hf_hub_download
# --- IMPORTANT: We must import our custom model class directly ---
# This assumes 'modeling_rx_codex_v3.py' is in the same directory
from modeling_rx_codex_v3 import Rx_Codex_V3_Custom_Model_Class
# --- Configuration ---
HF_REPO_ID = "rxmha125/Rx_Codex_V1_Tiny_V3"  # Hugging Face Hub repo holding config, tokenizer, and weights
MODEL_LOAD_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
# --- Logging Setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Global variables ---
# Filled in by the lifespan handler at startup; None means the API is "not ready".
model = None
tokenizer = None
# --- Application Lifespan (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and custom model at startup; release them at shutdown.

    Loading is done explicitly (config -> architecture -> downloaded weights ->
    state_dict) because the custom model class is not registered with
    AutoModel. On any failure both globals are reset to None so the endpoints
    report "not ready" instead of serving a half-initialized model.
    """
    global model, tokenizer
    logger.info(f"API Startup: Explicitly loading model '{HF_REPO_ID}' to device '{MODEL_LOAD_DEVICE}'...")
    try:
        # Tokenizer can still be loaded via the Auto class.
        tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)
        logger.info("✅ Tokenizer loaded successfully.")

        # --- EXPLICIT MODEL LOADING ---
        # 1. Load the configuration file
        config = GPT2Config.from_pretrained(HF_REPO_ID)
        logger.info("✅ Config loaded successfully.")

        # 2. Instantiate our custom model with the config
        model = Rx_Codex_V3_Custom_Model_Class(config)
        logger.info("✅ Custom model architecture instantiated.")

        # 3. Download the model weights file specifically
        weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename="pytorch_model.bin")
        logger.info("✅ Model weights downloaded successfully.")

        # 4. Load the state dictionary into our custom model
        state_dict = torch.load(weights_path, map_location=MODEL_LOAD_DEVICE)
        model.load_state_dict(state_dict)
        logger.info("✅ Weights loaded into custom model successfully.")

        # 5. Move to device and set to evaluation mode
        model.to(MODEL_LOAD_DEVICE)
        model.eval()
        logger.info("✅ Model is fully loaded and ready on the target device.")
    except Exception as e:
        logger.error(f"❌ FATAL: An error occurred during model loading: {e}", exc_info=True)
        # Reset both globals so the API consistently returns "not ready".
        model = None
        tokenizer = None

    yield

    # --- Code below this line runs on shutdown ---
    logger.info("API Shutting down.")
    model = None
    tokenizer = None
# --- Initialize FastAPI ---
# --- Initialize FastAPI ---
# lifespan= ties model loading/unloading to application startup/shutdown.
app = FastAPI(
    title="Rx Codex V1-Tiny-V3 API",
    description="An API for generating text with the Rx_Codex_V1_Tiny_V3 model.",
    lifespan=lifespan
)
# --- Pydantic Models for API Data Validation ---
class GenerationRequest(BaseModel):
    """Request body for POST /generate."""
    prompt: str                 # user text; wrapped in the Human/Assistant template by the endpoint
    max_new_tokens: int = 150   # upper bound on sampled tokens
    temperature: float = 0.7    # values <= 0 skip temperature scaling
    top_k: int = 50             # values <= 0 skip top-k filtering
class GenerationResponse(BaseModel):
    """Response body for POST /generate: only the newly generated text."""
    generated_text: str
# --- API Endpoints ---
@app.get("/")
def root():
    """A simple health-check endpoint.

    Reports whether both the model and tokenizer globals have been populated
    by the lifespan handler.
    """
    status = "loaded" if model and tokenizer else "not loaded"
    return {"message": "Rx Codex V1-Tiny-V3 API is running", "model_status": status}
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """The main endpoint to generate text from a prompt.

    Runs a manual sampling loop (the custom model class has no .generate()
    method): each step applies temperature scaling and top-k filtering to the
    last position's logits, samples one token, and stops early when the EOS
    token is drawn. Returns only the text generated after the prompt template.

    Raises:
        HTTPException: 503 when the model/tokenizer have not finished loading.
    """
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")

    logger.info(f"Received generation request for prompt: '{request.prompt}'")
    formatted_prompt = f"### Human:\n{request.prompt}\n\n### Assistant:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(MODEL_LOAD_DEVICE)

    # --- NOTE: Our custom model does not have a .generate() method ---
    # We must use our manual generation loop.
    output_ids = inputs["input_ids"]
    with torch.no_grad():
        for _ in range(request.max_new_tokens):
            outputs = model(output_ids)
            next_token_logits = outputs['logits'][:, -1, :]

            # Apply temperature (skipped when <= 0 to avoid division by zero)
            if request.temperature > 0:
                next_token_logits = next_token_logits / request.temperature

            # Apply top-k: mask everything below the k-th largest logit
            if request.top_k > 0:
                v, _ = torch.topk(next_token_logits, min(request.top_k, next_token_logits.size(-1)))
                next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')

            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            # Stop if EOS token is generated; .item() makes this an unambiguous
            # scalar comparison instead of tensor-vs-int truthiness.
            if next_token_id.item() == tokenizer.eos_token_id:
                break
            output_ids = torch.cat((output_ids, next_token_id), dim=1)

    full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Strip the prompt template so callers receive only the new completion.
    generated_text = full_text[len(formatted_prompt):].strip()
    logger.info("Generation complete.")
    return GenerationResponse(generated_text=generated_text)