# main.py (Corrected)
"""FastAPI service that generates text with the Rx_Codex_V1_Tiny_V3 model.

The model uses a custom architecture class, so it is built explicitly from its
GPT2Config plus a downloaded state dict, and text is produced with a manual
sampling loop because the custom class has no ``.generate()`` method.
"""
import logging
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoTokenizer, GPT2Config
from huggingface_hub import hf_hub_download

# --- IMPORTANT: We must import our custom model class directly ---
# This assumes 'modeling_rx_codex_v3.py' is in the same directory
from modeling_rx_codex_v3 import Rx_Codex_V3_Custom_Model_Class

# --- Configuration ---
HF_REPO_ID = "rxmha125/Rx_Codex_V1_Tiny_V3"
MODEL_LOAD_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Global variables ---
# Set by `lifespan` at startup; None means the service is "not ready".
model = None
tokenizer = None
# Maximum model input length, read from the config's `n_positions` attribute
# (None if the config has no such attribute). Used to crop the model input.
max_context_length = None


def _load_model_and_tokenizer():
    """Build the tokenizer, config and custom model for ``HF_REPO_ID``.

    Kept separate from ``lifespan`` so that the temporary ``state_dict`` (a
    full second copy of the weights) is released when this function returns,
    instead of living in the suspended ``lifespan`` generator frame for the
    whole server lifetime.

    Returns:
        ``(model, tokenizer, config)`` with the model moved to
        ``MODEL_LOAD_DEVICE`` and switched to eval mode.

    Raises:
        Whatever the Hub download, config parsing or weight loading raise.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)
    logger.info("✅ Tokenizer loaded successfully.")

    # --- EXPLICIT MODEL LOADING ---
    # 1. Load the configuration file
    config = GPT2Config.from_pretrained(HF_REPO_ID)
    logger.info("✅ Config loaded successfully.")

    # 2. Instantiate our custom model with the config
    loaded_model = Rx_Codex_V3_Custom_Model_Class(config)
    logger.info("✅ Custom model architecture instantiated.")

    # 3. Download the model weights file specifically
    weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename="pytorch_model.bin")
    logger.info("✅ Model weights downloaded successfully.")

    # 4. Load the state dictionary into our custom model.
    # SECURITY: the file comes from a remote repo; weights_only=True restricts
    # unpickling to tensors/containers instead of arbitrary Python objects
    # (requires torch >= 1.13). Loading on CPU avoids a transient extra copy of
    # the weights on the GPU; the model is moved to the device in step 5.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    loaded_model.load_state_dict(state_dict)
    del state_dict  # load_state_dict copied the tensors; drop the duplicate now.
    logger.info("✅ Weights loaded into custom model successfully.")

    # 5. Move to device and set to evaluation mode
    loaded_model.to(MODEL_LOAD_DEVICE)
    loaded_model.eval()
    logger.info("✅ Model is fully loaded and ready on the target device.")
    return loaded_model, loaded_tokenizer, config


# --- Application Lifespan (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model at startup and release the global references at shutdown.

    A loading failure is logged and leaves the globals as None, so the API
    starts but reports "not loaded" / returns 503 instead of crashing.
    """
    global model, tokenizer, max_context_length
    logger.info(
        "API Startup: Explicitly loading model '%s' to device '%s'...",
        HF_REPO_ID,
        MODEL_LOAD_DEVICE,
    )
    try:
        # Globals are only assigned once everything has loaded successfully.
        model, tokenizer, config = _load_model_and_tokenizer()
        max_context_length = getattr(config, "n_positions", None)
    except Exception as e:
        logger.exception("❌ FATAL: An error occurred during model loading: %s", e)
        # Set model to None to ensure API returns "not ready"
        model = None
        tokenizer = None
        max_context_length = None

    yield

    # --- Code below this line runs on shutdown ---
    logger.info("API Shutting down.")
    model = None
    tokenizer = None
    max_context_length = None


# --- Initialize FastAPI ---
app = FastAPI(
    title="Rx Codex V1-Tiny-V3 API",
    description="An API for generating text with the Rx_Codex_V1_Tiny_V3 model.",
    lifespan=lifespan,
)


# --- Pydantic Models for API Data Validation ---
class GenerationRequest(BaseModel):
    """Request body for ``POST /generate``."""

    prompt: str
    # Maximum number of tokens appended after the prompt; stops early on EOS.
    max_new_tokens: int = 150
    # Softmax temperature; values <= 0 select greedy (argmax) decoding.
    temperature: float = 0.7
    # Sample only among the k highest-scoring tokens; values <= 0 disable it.
    top_k: int = 50


class GenerationResponse(BaseModel):
    """Response body for ``POST /generate``."""

    generated_text: str


def _sample_next_token(logits, temperature, top_k):
    """Choose the next token id from last-position logits.

    Args:
        logits: 2-D tensor indexed ``[batch, vocab]`` (the ``[:, -1, :]`` slice
            of the model output).
        temperature: values <= 0 mean greedy argmax; otherwise logits are
            divided by it before sampling.
        top_k: if > 0, tokens below the k-th largest logit are masked out.

    Returns:
        Tensor of token ids with shape ``[batch, 1]``.
    """
    if temperature <= 0:
        # The temperature -> 0 limit of sampling is argmax (greedy decoding).
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k).values[:, [-1]]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def _run_generation(gen_model, gen_tokenizer, request, context_length):
    """Run the blocking tokenize -> sample loop -> decode pipeline.

    Executed in a worker thread (see ``generate_text``), so ``torch.no_grad``
    is entered here: autograd's grad mode is thread-local.

    Returns:
        The decoded text of the newly generated tokens only, stripped.
    """
    formatted_prompt = f"### Human:\n{request.prompt}\n\n### Assistant:"
    inputs = gen_tokenizer(formatted_prompt, return_tensors="pt").to(MODEL_LOAD_DEVICE)

    # --- NOTE: Our custom model does not have a .generate() method ---
    # We must use our manual generation loop
    output_ids = inputs["input_ids"]
    prompt_length = output_ids.size(1)
    eos_token_id = gen_tokenizer.eos_token_id

    with torch.no_grad():
        for _ in range(request.max_new_tokens):
            # Crop to the configured window so the model is never fed more
            # positions than n_positions. NOTE(review): assumes the custom model,
            # like GPT-2, cannot handle inputs longer than n_positions — confirm
            # against modeling_rx_codex_v3.
            model_input = output_ids
            if context_length is not None and model_input.size(1) > context_length:
                model_input = model_input[:, -context_length:]

            next_token_logits = gen_model(model_input)["logits"][:, -1, :]
            next_token_id = _sample_next_token(
                next_token_logits, request.temperature, request.top_k
            )

            # Stop if EOS token is generated (EOS itself is not kept).
            if eos_token_id is not None and next_token_id.item() == eos_token_id:
                break
            output_ids = torch.cat((output_ids, next_token_id), dim=1)

    # Decode only the new tokens: slicing the full decoded string by
    # len(formatted_prompt) is wrong whenever decoding does not reproduce the
    # prompt text character-for-character.
    return gen_tokenizer.decode(
        output_ids[0, prompt_length:], skip_special_tokens=True
    ).strip()


# --- API Endpoints ---
@app.get("/")
def root():
    """A simple endpoint to check if the API is running."""
    status = "loaded" if model is not None and tokenizer is not None else "not loaded"
    return {"message": "Rx Codex V1-Tiny-V3 API is running", "model_status": status}


@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """The main endpoint to generate text from a prompt.

    Raises:
        HTTPException: 503 if the model or tokenizer failed to load.
    """
    # Snapshot the globals so a concurrent shutdown cannot swap them mid-request.
    current_model, current_tokenizer = model, tokenizer
    if current_model is None or current_tokenizer is None:
        raise HTTPException(
            status_code=503, detail="Model is not ready. Please try again later."
        )

    logger.info("Received generation request for prompt: '%s'", request.prompt)

    # Generation is CPU/GPU-bound; running it in the thread pool keeps the
    # event loop free to serve other requests (e.g. the health check).
    generated_text = await run_in_threadpool(
        _run_generation, current_model, current_tokenizer, request, max_context_length
    )

    logger.info("Generation complete.")
    return GenerationResponse(generated_text=generated_text)