import sys
import pathlib

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

sys.path.append(str(pathlib.Path(__file__).parent.resolve()))

from tokenizer import Tokenizer
from model.generator import Generator
from model.encoder import Encoder
from model.decoder import Decoder
from model.attn import BahdanauAttention

app = FastAPI()

BASE_DIR = pathlib.Path(__file__).parent
TOKENIZER_PATH = BASE_DIR / "tokenizer.json"
CHECKPOINT_PATH = BASE_DIR / "best_model.pth"

# Architecture hyperparameters; these must match the values used at training
# time, or loading the checkpoint will fail.
VOCAB_SIZE = 8000
EMBED_SIZE = 128
HIDDEN_SIZE = 256
NUM_LAYERS = 3
DROPOUT = 0.2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Populated once at startup; None until load_resources() runs.
tokenizer = None
model = None
SOS_IDX = None
EOS_IDX = None
PAD_IDX = None


class GenerationRequest(BaseModel):
    code_snippet: str
    cls: str = "parallel"  # default pragma class tag
    max_len: int = 100


@app.on_event("startup")
def load_resources():
    global tokenizer, model, SOS_IDX, EOS_IDX, PAD_IDX

    if not TOKENIZER_PATH.exists():
        raise FileNotFoundError(f"Tokenizer not found at {TOKENIZER_PATH}")

    tokenizer = Tokenizer(vocab_size=VOCAB_SIZE)
    tokenizer.load(str(TOKENIZER_PATH))

    # NOTE: the special-token strings below are assumptions (the originals
    # were lost to markup stripping); use whatever tokens your Tokenizer
    # vocabulary actually defines.
    SOS_IDX = tokenizer.char2idx['<sos>']
    EOS_IDX = tokenizer.char2idx['<eos>']
    PAD_IDX = tokenizer.char2idx['<pad>']

    actual_vocab_size = tokenizer.vocab_size
    encoder = Encoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT)
    attention = BahdanauAttention(HIDDEN_SIZE)
    decoder = Decoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, attention, NUM_LAYERS, DROPOUT)
    model = Generator(encoder, decoder, device).to(device)

    if not CHECKPOINT_PATH.exists():
        print("WARNING: Checkpoint not found. Model will be random!")
        return

    checkpoint = torch.load(str(CHECKPOINT_PATH), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()


def greedy_generate(code_snippet: str, cls: str, max_len: int) -> str:
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    model.eval()

    # Prepend the class-conditioning tag unless the caller already supplied one.
    text = code_snippet if code_snippet.startswith("[CLS:") else f"[CLS:{cls}] {code_snippet}"
    input_ids = tokenizer.encode(text, max_length=1500, add_special_tokens=True)

    # The true sequence length is the position of the first pad token.
    input_len = next((i for i, tok in enumerate(input_ids) if tok == PAD_IDX), len(input_ids))
    input_tensor = torch.tensor([input_ids], device=device)
    input_len_tensor = torch.tensor([input_len], device=device)

    with torch.no_grad():
        enc_outs, hidden, cell = model.encoder(input_tensor, input_len_tensor)

        # Attention mask: 1 for real tokens, 0 for padding.
        mask = (torch.arange(enc_outs.size(1), device=device).unsqueeze(0)
                < input_len_tensor.unsqueeze(1)).float()

        # Merge the two directions of the bidirectional encoder and project
        # down to the decoder's hidden size (batch size is 1 here).
        hidden = hidden.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)
        hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)
        hidden = model.hidden_projection(hidden)

        cell = cell.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)
        cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)
        cell = model.cell_projection(cell)

        # Greedy decoding: feed the argmax token back in until EOS or max_len.
        input_token = torch.tensor([SOS_IDX], device=device)
        generated = []
        for _ in range(max_len):
            output, hidden, cell, _ = model.decoder(input_token, hidden, cell, enc_outs, mask)
            top1 = output.argmax(1)
            token_id = top1.item()
            if token_id == EOS_IDX:
                break
            generated.append(token_id)
            input_token = top1

    return tokenizer.decode(generated)


@app.post("/generate")
def generate_code_snippet(request: GenerationRequest):
    try:
        if not request.code_snippet.strip():
            return {"pragma": ""}
        cleaned_code = request.code_snippet.strip()
        result = greedy_generate(cleaned_code, request.cls, request.max_len)
        return {"pragma": result}
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. the 503 raised while the
        # model is still loading) instead of wrapping them in a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
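

# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch, not part of the service. It exercises
# the /generate endpoint in-process with FastAPI's TestClient; entering the
# context manager fires the startup event, so the tokenizer and checkpoint
# are loaded (both are assumed to exist next to this module). In production
# you would instead serve the app with e.g. `uvicorn <module_name>:app`,
# where the module name depends on this file's name.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    with TestClient(app) as client:
        response = client.post(
            "/generate",
            json={
                # The loop below is an arbitrary example input.
                "code_snippet": "for (int i = 0; i < n; i++) a[i] = b[i] + c[i];",
                "cls": "parallel",
                "max_len": 50,
            },
        )
        print(response.status_code, response.json())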