import sys
import pathlib

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Make the local modules importable regardless of the working directory.
sys.path.append(str(pathlib.Path(__file__).parent.resolve()))

from tokenizer import Tokenizer
from model.generator import Generator
from model.encoder import Encoder
from model.decoder import Decoder
from model.attn import BahdanauAttention

app = FastAPI()

BASE_DIR = pathlib.Path(__file__).parent
TOKENIZER_PATH = BASE_DIR / "tokenizer.json"
CHECKPOINT_PATH = BASE_DIR / "best_model.pth"

# Hyperparameters; these must match the values used during training.
VOCAB_SIZE = 8000
EMBED_SIZE = 128
HIDDEN_SIZE = 256
NUM_LAYERS = 3
DROPOUT = 0.2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Populated once at startup by load_resources().
tokenizer = None
model = None
SOS_IDX = None
EOS_IDX = None
PAD_IDX = None


class GenerationRequest(BaseModel):
    code_snippet: str
    cls: str = "parallel"  # default class tag
    max_len: int = 100
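
# Example request body (illustrative values; the C snippet is an assumption,
# not part of the original code):
#   {"code_snippet": "for (int i = 0; i < n; i++) a[i] = b[i] + c[i];",
#    "cls": "parallel", "max_len": 100}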


@app.on_event("startup")
def load_resources():
    global tokenizer, model, SOS_IDX, EOS_IDX, PAD_IDX

    if not TOKENIZER_PATH.exists():
        raise FileNotFoundError(f"Tokenizer not found at {TOKENIZER_PATH}")
    tokenizer = Tokenizer(vocab_size=VOCAB_SIZE)
    tokenizer.load(str(TOKENIZER_PATH))
    SOS_IDX = tokenizer.char2idx['<SOS>']
    EOS_IDX = tokenizer.char2idx['<EOS>']
    PAD_IDX = tokenizer.char2idx['<PAD>']
    # The loaded tokenizer's vocabulary may differ from the requested size.
    actual_vocab_size = tokenizer.vocab_size

    encoder = Encoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT)
    attention = BahdanauAttention(HIDDEN_SIZE)
    decoder = Decoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, attention, NUM_LAYERS, DROPOUT)
    model = Generator(encoder, decoder, device).to(device)

    if not CHECKPOINT_PATH.exists():
        print("WARNING: Checkpoint not found. Model weights will be random!")
        return
    checkpoint = torch.load(str(CHECKPOINT_PATH), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()


def greedy_generate(code_snippet: str, cls: str, max_len: int) -> str:
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    model.eval()

    # Prepend the class tag unless the caller already supplied one inline.
    text = code_snippet if code_snippet.startswith("[CLS:") else f"[CLS:{cls}] {code_snippet}"
    input_ids = tokenizer.encode(text, max_length=1500, add_special_tokens=True)
    # True sequence length = index of the first PAD token (or the full length).
    input_len = next((i for i, tok in enumerate(input_ids) if tok == PAD_IDX), len(input_ids))
    input_tensor = torch.tensor([input_ids], device=device)
    input_len_tensor = torch.tensor([input_len], device=device)

    with torch.no_grad():
        enc_outs, hidden, cell = model.encoder(input_tensor, input_len_tensor)
        # Attention mask: 1.0 over real tokens, 0.0 over padding.
        mask = (torch.arange(enc_outs.size(1), device=device).unsqueeze(0) < input_len_tensor.unsqueeze(1)).float()

        # Merge the bidirectional encoder states per layer and project them
        # down to the decoder's hidden size.
        hidden = hidden.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)
        hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)
        hidden = model.hidden_projection(hidden)
        cell = cell.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)
        cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)
        cell = model.cell_projection(cell)

        # Greedy decoding: feed the argmax token back in until <EOS> or max_len.
        input_token = torch.tensor([SOS_IDX], device=device)
        generated = []
        for _ in range(max_len):
            output, hidden, cell, _ = model.decoder(input_token, hidden, cell, enc_outs, mask)
            top1 = output.argmax(1)
            token_id = top1.item()
            if token_id == EOS_IDX:
                break
            generated.append(token_id)
            input_token = top1

    return tokenizer.decode(generated)


@app.post("/generate")
def generate_code_snippet(request: GenerationRequest):
    try:
        cleaned_code = request.code_snippet.strip()
        if not cleaned_code:
            return {"pragma": ""}
        result = greedy_generate(cleaned_code, request.cls, request.max_len)
        return {"pragma": result}
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. the 503 from greedy_generate).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
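

# A minimal usage sketch (assumes this file is saved as app.py and that
# uvicorn is installed; the module name, port, and example snippet are
# assumptions, not part of the original code):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"code_snippet": "for (int i = 0; i < n; i++) a[i] = b[i] + c[i];", "cls": "parallel"}'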