File size: 5,343 Bytes
42fb1be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# main.py (Corrected)

import logging
from contextlib import asynccontextmanager
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, GPT2Config
from huggingface_hub import hf_hub_download

# --- IMPORTANT: We must import our custom model class directly ---
# This assumes 'modeling_rx_codex_v3.py' is in the same directory
from modeling_rx_codex_v3 import Rx_Codex_V3_Custom_Model_Class

# --- Configuration ---
# Hugging Face Hub repo that hosts the tokenizer, config, and weights.
HF_REPO_ID = "rxmha125/Rx_Codex_V1_Tiny_V3"
# Prefer GPU when available; weights and inputs are both moved to this device.
MODEL_LOAD_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Global variables ---
# Populated by the lifespan handler at startup; None means "not ready",
# which the endpoints use to return 503 / "not loaded".
model = None
tokenizer = None

# --- Application Lifespan (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: load the model and tokenizer at startup.

    The repo hosts a custom architecture that AutoModel cannot resolve, so
    loading is explicit: config -> instantiate local class -> load weights.
    On any failure both globals are reset to None so the endpoints report
    "not ready" instead of serving a half-initialized model.
    """
    global model, tokenizer
    logger.info(f"API Startup: Explicitly loading model '{HF_REPO_ID}' to device '{MODEL_LOAD_DEVICE}'...")

    try:
        # Tokenizer resolves directly from the Hub repo.
        tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)
        logger.info("✅ Tokenizer loaded successfully.")

        # 1. Load the configuration file
        config = GPT2Config.from_pretrained(HF_REPO_ID)
        logger.info("✅ Config loaded successfully.")

        # 2. Instantiate our custom model with the config
        model = Rx_Codex_V3_Custom_Model_Class(config)
        logger.info("✅ Custom model architecture instantiated.")

        # 3. Download the model weights file specifically
        weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename="pytorch_model.bin")
        logger.info("✅ Model weights downloaded successfully.")

        # 4. Load the state dict. weights_only=True restricts unpickling to
        # tensors so a tampered checkpoint cannot execute arbitrary code
        # (pytorch_model.bin is a pickle file).
        state_dict = torch.load(weights_path, map_location=MODEL_LOAD_DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        logger.info("✅ Weights loaded into custom model successfully.")

        # 5. Move to device and disable dropout etc. for inference.
        model.to(MODEL_LOAD_DEVICE)
        model.eval()
        logger.info("✅ Model is fully loaded and ready on the target device.")

    except Exception as e:
        logger.error(f"❌ FATAL: An error occurred during model loading: {e}", exc_info=True)
        # Reset both globals so "/" and "/generate" report the model as not ready.
        model = None
        tokenizer = None

    yield

    # --- Shutdown: drop references so memory can be reclaimed ---
    logger.info("API Shutting down.")
    model = None
    tokenizer = None

# --- Initialize FastAPI ---
# The lifespan handler above loads the model before the app serves requests.
app = FastAPI(
    title="Rx Codex V1-Tiny-V3 API",
    description="An API for generating text with the Rx_Codex_V1_Tiny_V3 model.",
    lifespan=lifespan
)

# --- Pydantic Models for API Data Validation ---
class GenerationRequest(BaseModel):
    """Request body for /generate: a prompt plus sampling hyperparameters."""
    prompt: str
    max_new_tokens: int = 150  # upper bound on tokens generated after the prompt
    temperature: float = 0.7   # values <= 0 skip temperature scaling
    top_k: int = 50            # values <= 0 skip top-k filtering

class GenerationResponse(BaseModel):
    """Response body for /generate: the completion text (prompt stripped)."""
    generated_text: str

# --- API Endpoints ---
@app.get("/")
def root():
    """Health-check endpoint: reports whether the model is ready to serve.

    Returns a JSON object with a static message and a model_status of
    "loaded" or "not loaded".
    """
    # Explicit None checks: object truthiness is not a reliable readiness
    # signal (arbitrary objects may define __bool__/__len__).
    ready = model is not None and tokenizer is not None
    status = "loaded" if ready else "not loaded"
    return {"message": "Rx Codex V1-Tiny-V3 API is running", "model_status": status}

def _sample_next_token(next_token_logits: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
    """Sample one token id from last-position logits of shape (batch, vocab).

    Applies temperature scaling (when temperature > 0) and top-k filtering
    (when top_k > 0), then draws from the resulting distribution.
    Returns a (batch, 1) tensor of sampled token ids.
    """
    if temperature > 0:
        next_token_logits = next_token_logits / temperature
    if top_k > 0:
        k = min(top_k, next_token_logits.size(-1))
        top_values, _ = torch.topk(next_token_logits, k)
        # Mask out every logit below the k-th largest so it cannot be sampled.
        next_token_logits = next_token_logits.masked_fill(
            next_token_logits < top_values[:, [-1]], -float('Inf')
        )
    probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """Generate a completion for the given prompt with a manual sampling loop.

    Raises:
        HTTPException: 503 when the model/tokenizer failed to load at startup.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")

    logger.info(f"Received generation request for prompt: '{request.prompt}'")

    # Chat-style template; the same prefix is stripped from the output below.
    formatted_prompt = f"### Human:\n{request.prompt}\n\n### Assistant:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(MODEL_LOAD_DEVICE)

    # --- NOTE: Our custom model does not have a .generate() method ---
    # We must sample token-by-token, feeding the growing sequence back in.
    output_ids = inputs["input_ids"]
    with torch.no_grad():
        for _ in range(request.max_new_tokens):
            outputs = model(output_ids)
            next_token_id = _sample_next_token(
                outputs['logits'][:, -1, :], request.temperature, request.top_k
            )

            # .item() gives a plain int: avoids relying on implicit truthiness
            # of a single-element tensor in the comparison.
            if next_token_id.item() == tokenizer.eos_token_id:
                break

            output_ids = torch.cat((output_ids, next_token_id), dim=1)

    full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Strip the echoed prompt so only the newly generated text is returned.
    generated_text = full_text[len(formatted_prompt):].strip()

    logger.info("Generation complete.")
    return GenerationResponse(generated_text=generated_text)