# NOTE(review): removed paste artifacts ("Spaces:" / "Runtime error" lines) —
# they were tool output captured with the source, not valid Python.
# main.py (Corrected)
import logging
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, GPT2Config
from huggingface_hub import hf_hub_download

# --- IMPORTANT: We must import our custom model class directly ---
# This assumes 'modeling_rx_codex_v3.py' is in the same directory
from modeling_rx_codex_v3 import Rx_Codex_V3_Custom_Model_Class

# --- Configuration ---
HF_REPO_ID = "rxmha125/Rx_Codex_V1_Tiny_V3"
MODEL_LOAD_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Global variables (populated by the lifespan handler below) ---
model = None
tokenizer = None
| # --- Application Lifespan (Model Loading) --- | |
# --- Application Lifespan (Model Loading) ---
# FIX: FastAPI requires the lifespan handler to be an async context manager.
# The decorator was missing even though asynccontextmanager is imported above;
# without it, app startup fails with a TypeError.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and custom model at startup; release them at shutdown.

    On any loading error both globals are reset to None so the endpoints
    report 503 "not ready" instead of serving a half-initialized model.
    """
    global model, tokenizer
    logger.info(f"API Startup: Explicitly loading model '{HF_REPO_ID}' to device '{MODEL_LOAD_DEVICE}'...")
    try:
        # Load tokenizer as before
        tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)
        logger.info("β Tokenizer loaded successfully.")
        # --- EXPLICIT MODEL LOADING ---
        # 1. Load the configuration file
        config = GPT2Config.from_pretrained(HF_REPO_ID)
        logger.info("β Config loaded successfully.")
        # 2. Instantiate our custom model with the config
        model = Rx_Codex_V3_Custom_Model_Class(config)
        logger.info("β Custom model architecture instantiated.")
        # 3. Download the model weights file specifically
        weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename="pytorch_model.bin")
        logger.info("β Model weights downloaded successfully.")
        # 4. Load the state dictionary into our custom model
        # NOTE(review): torch.load unpickles arbitrary data; the repo is trusted
        # here, but consider weights_only=True on torch versions that support it.
        state_dict = torch.load(weights_path, map_location=MODEL_LOAD_DEVICE)
        model.load_state_dict(state_dict)
        logger.info("β Weights loaded into custom model successfully.")
        # 5. Move to device and set to evaluation mode
        model.to(MODEL_LOAD_DEVICE)
        model.eval()
        logger.info("β Model is fully loaded and ready on the target device.")
    except Exception as e:
        logger.error(f"β FATAL: An error occurred during model loading: {e}", exc_info=True)
        # Set model to None to ensure API returns "not ready"
        model = None
        tokenizer = None
    yield
    # --- Code below this line runs on shutdown ---
    logger.info("API Shutting down.")
    model = None
    tokenizer = None
| # --- Initialize FastAPI --- | |
# --- Initialize FastAPI ---
# The lifespan handler wires model loading/unloading into app startup/shutdown.
app = FastAPI(
    title="Rx Codex V1-Tiny-V3 API",
    description="An API for generating text with the Rx_Codex_V1_Tiny_V3 model.",
    lifespan=lifespan,
)
| # --- Pydantic Models for API Data Validation --- | |
# --- Pydantic Models for API Data Validation --- |
class GenerationRequest(BaseModel):
    """Request payload for the /generate endpoint."""

    prompt: str                # user text; wrapped in the Human/Assistant template
    max_new_tokens: int = 150  # upper bound on tokens sampled in the loop
    temperature: float = 0.7   # logits divided by this when > 0; <= 0 skips scaling
    top_k: int = 50            # keep only the k most likely tokens when > 0
class GenerationResponse(BaseModel):
    """Response payload: the assistant text with the prompt template stripped."""

    generated_text: str  # decoded continuation only (prompt prefix removed)
| # --- API Endpoints --- | |
# --- API Endpoints ---
# FIX: the route decorator was missing, so this handler was never registered
# with the app and "/" returned 404.
@app.get("/")
def root():
    """A simple endpoint to check if the API is running."""
    # Both globals must be populated by the lifespan handler for "loaded".
    status = "loaded" if model and tokenizer else "not loaded"
    return {"message": "Rx Codex V1-Tiny-V3 API is running", "model_status": status}
# FIX: the route decorator was missing, so this handler was never registered
# with the app and the generation endpoint was unreachable.
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """The main endpoint to generate text from a prompt.

    Runs a manual temperature + top-k sampling loop because the custom model
    class does not provide a .generate() method. Raises HTTP 503 when the
    model/tokenizer failed to load at startup.
    """
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")
    logger.info(f"Received generation request for prompt: '{request.prompt}'")
    formatted_prompt = f"### Human:\n{request.prompt}\n\n### Assistant:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(MODEL_LOAD_DEVICE)
    # --- NOTE: Our custom model does not have a .generate() method ---
    # We must use our manual generation loop
    output_ids = inputs["input_ids"]
    with torch.no_grad():
        for _ in range(request.max_new_tokens):
            outputs = model(output_ids)
            # Logits for the final position only: shape (batch, vocab).
            next_token_logits = outputs['logits'][:, -1, :]
            # Apply temperature
            if request.temperature > 0:
                next_token_logits = next_token_logits / request.temperature
            # Apply top-k
            if request.top_k > 0:
                v, _ = torch.topk(next_token_logits, min(request.top_k, next_token_logits.size(-1)))
                # Mask everything strictly below the k-th largest logit.
                next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
            # Stop if EOS token is generated (EOS itself is not appended)
            if next_token_id == tokenizer.eos_token_id:
                break
            output_ids = torch.cat((output_ids, next_token_id), dim=1)
    full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Strip the prompt template so only the assistant's reply is returned.
    generated_text = full_text[len(formatted_prompt):].strip()
    logger.info("Generation complete.")
    return GenerationResponse(generated_text=generated_text)