# Uploaded by Testsdft ("Upload 4 files", commit 42fb1be, verified)
# main.py (Corrected)
import logging
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, GPT2Config

# --- IMPORTANT: We must import our custom model class directly ---
# This assumes 'modeling_rx_codex_v3.py' is in the same directory
from modeling_rx_codex_v3 import Rx_Codex_V3_Custom_Model_Class
# --- Configuration ---
# Hugging Face Hub repository that hosts the tokenizer, config and weights.
HF_REPO_ID = "rxmha125/Rx_Codex_V1_Tiny_V3"
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    MODEL_LOAD_DEVICE = "cuda"
else:
    MODEL_LOAD_DEVICE = "cpu"

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Global variables ---
# Filled in by the lifespan handler at startup; both remain None until then
# (and are reset to None if loading fails or the app shuts down).
tokenizer = None
model = None
# --- Application Lifespan (Model Loading) ---
def _load_model_and_tokenizer():
    """Download the tokenizer, config and weights for ``HF_REPO_ID`` and build the model.

    Kept as a separate function so that large temporaries (notably the loaded
    state dict) are released when it returns. Locals of the lifespan generator
    itself stay alive for as long as the app runs, because it is suspended at
    ``yield``.

    Returns:
        ``(model, tokenizer)``: the custom model is on ``MODEL_LOAD_DEVICE`` and
        in eval mode.

    Raises:
        Whatever the tokenizer/config loaders, ``hf_hub_download``, ``torch.load``
        or ``load_state_dict`` raise; the caller decides how to handle it.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)
    logger.info("✅ Tokenizer loaded successfully.")

    # --- EXPLICIT MODEL LOADING ---
    # 1. Load the configuration file
    config = GPT2Config.from_pretrained(HF_REPO_ID)
    logger.info("✅ Config loaded successfully.")

    # 2. Instantiate our custom model with the config
    loaded_model = Rx_Codex_V3_Custom_Model_Class(config)
    logger.info("✅ Custom model architecture instantiated.")

    # 3. Download the model weights file specifically
    weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename="pytorch_model.bin")
    logger.info("✅ Model weights downloaded successfully.")

    # 4. Load the state dictionary into our custom model.
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint fetched from the Hub cannot run arbitrary code on load.
    state_dict = torch.load(weights_path, map_location=MODEL_LOAD_DEVICE, weights_only=True)
    loaded_model.load_state_dict(state_dict)
    logger.info("✅ Weights loaded into custom model successfully.")

    # 5. Move to device and set to evaluation mode
    loaded_model.to(MODEL_LOAD_DEVICE)
    loaded_model.eval()
    logger.info("✅ Model is fully loaded and ready on the target device.")
    return loaded_model, loaded_tokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and tokenizer on startup and drop them on shutdown.

    A loading failure is logged and leaves both globals as ``None``. The server
    keeps running, and the endpoints that check readiness report the model as
    not loaded instead of the process crashing.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    global model, tokenizer
    logger.info(
        "API Startup: Explicitly loading model '%s' to device '%s'...",
        HF_REPO_ID,
        MODEL_LOAD_DEVICE,
    )
    try:
        # Both globals are published together, and only after everything loaded.
        model, tokenizer = _load_model_and_tokenizer()
    except Exception as e:
        logger.exception("❌ FATAL: An error occurred during model loading: %s", e)
        # Set model to None to ensure API returns "not ready"
        model = None
        tokenizer = None

    try:
        yield
    finally:
        # --- Runs on shutdown, even if the app is torn down by an exception ---
        logger.info("API Shutting down.")
        model = None
        tokenizer = None
# --- Initialize FastAPI ---
# `lifespan` performs model loading at startup and cleanup at shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="Rx Codex V1-Tiny-V3 API",
    description="An API for generating text with the Rx_Codex_V1_Tiny_V3 model.",
)
# --- Pydantic Models for API Data Validation ---
class GenerationRequest(BaseModel):
    """Body of a ``POST /generate`` request.

    Numeric knobs are constrained to be non-negative, so clearly invalid
    requests are rejected with a validation error instead of being run.
    """

    # User text; the endpoint wraps it in the "### Human: / ### Assistant:" template.
    prompt: str
    # Maximum number of tokens to generate; generation may stop earlier on EOS.
    max_new_tokens: int = Field(150, ge=0)
    # Sampling temperature applied to the next-token logits.
    temperature: float = Field(0.7, ge=0)
    # Sample only from the k highest-scoring tokens; 0 disables this filter.
    top_k: int = Field(50, ge=0)


class GenerationResponse(BaseModel):
    """Body of a ``POST /generate`` response."""

    # The model's continuation only; the formatted prompt is not echoed back.
    generated_text: str
# --- API Endpoints ---
@app.get("/")
def root():
    """A simple endpoint to check if the API is running.

    Returns:
        A JSON object whose ``model_status`` is ``"loaded"`` only when both the
        model and the tokenizer are set, otherwise ``"not loaded"``.
    """
    # Compare against None explicitly. Truthiness may defer to ``__len__`` or
    # ``__bool__`` on these objects, which says nothing about readiness.
    ready = model is not None and tokenizer is not None
    status = "loaded" if ready else "not loaded"
    return {"message": "Rx Codex V1-Tiny-V3 API is running", "model_status": status}
def _sample_next_token(logits, temperature, top_k):
    """Choose the next token id from last-position logits.

    Args:
        logits: Scores for the last position, indexed ``[batch, vocab]`` (the
            caller slices ``[:, -1, :]`` from the model output).
        temperature: Softmax temperature. A value ``<= 0`` selects greedy
            decoding (argmax). This avoids a division by zero and matches the
            usual meaning of "temperature 0".
        top_k: When ``> 0``, sample only among the ``top_k`` highest-scoring
            tokens; ``<= 0`` disables the filter.

    Returns:
        A ``[batch, 1]`` tensor of token ids.
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    scaled = logits / temperature
    if top_k > 0:
        k = min(top_k, scaled.size(-1))
        kth_best = torch.topk(scaled, k).values[:, [-1]]
        # Out-of-place masking: never mutate the model's output tensor.
        scaled = scaled.masked_fill(scaled < kth_best, float("-inf"))
    probs = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def _generate_ids(active_model, prompt_ids, max_new_tokens, temperature, top_k, eos_token_id):
    """Run the autoregressive sampling loop and return only the new token ids.

    The custom model has no ``.generate()`` method, so tokens are produced one
    at a time. Generation stops early when ``eos_token_id`` is sampled; the EOS
    token itself is not included in the result.

    Returns:
        A ``[1, n_new]`` tensor holding the generated ids, excluding the prompt.
    """
    # NOTE(review): assumes a GPT-2 style ``config.n_positions`` limit on the
    # model, if it stores its config. When present, feed only the most recent
    # tokens so long prompts do not overrun the position range. TODO confirm.
    max_context = getattr(getattr(active_model, "config", None), "n_positions", None)

    output_ids = prompt_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            context = output_ids[:, -max_context:] if max_context else output_ids
            next_token_logits = active_model(context)["logits"][:, -1, :]
            next_token_id = _sample_next_token(next_token_logits, temperature, top_k)
            # Stop if EOS token is generated
            if eos_token_id is not None and next_token_id.item() == eos_token_id:
                break
            output_ids = torch.cat((output_ids, next_token_id), dim=1)
    return output_ids[:, prompt_ids.size(1):]


@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """The main endpoint to generate text from a prompt.

    The prompt is wrapped in the "### Human: / ### Assistant:" template. Only
    the continuation produced by the model is returned.

    Raises:
        HTTPException: 503 when the model or tokenizer is not loaded.
    """
    # Imported here to keep this endpoint self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    # Snapshot the globals. The lifespan shutdown resets them to None, and that
    # must not break a request that is already running.
    active_model, active_tokenizer = model, tokenizer
    if active_model is None or active_tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")

    # %r escapes newlines/control characters so a prompt cannot forge log lines.
    logger.info("Received generation request for prompt: %r", request.prompt)
    formatted_prompt = f"### Human:\n{request.prompt}\n\n### Assistant:"
    inputs = active_tokenizer(formatted_prompt, return_tensors="pt").to(MODEL_LOAD_DEVICE)

    # The token loop is CPU/GPU-bound. Running it in a worker thread keeps the
    # event loop free to serve other requests, such as "/" health checks.
    new_ids = await run_in_threadpool(
        _generate_ids,
        active_model,
        inputs["input_ids"],
        request.max_new_tokens,
        request.temperature,
        request.top_k,
        active_tokenizer.eos_token_id,
    )

    # Decode only the generated ids. Slicing the full decoded text by the
    # prompt's character length breaks whenever decode(encode(prompt)) != prompt.
    generated_text = active_tokenizer.decode(new_ids[0], skip_special_tokens=True).strip()
    logger.info("Generation complete.")
    return GenerationResponse(generated_text=generated_text)