# NOTE(review): the original first three lines ("Spaces:" / "Sleeping" /
# "Sleeping") were Hugging Face Spaces status-banner text accidentally
# pasted into the source file; they are not code and have been neutralized.
import sys
import pathlib
import os
import torch
import re
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Make the sibling modules below importable regardless of the launch cwd.
sys.path.append(str(pathlib.Path(__file__).parent.resolve()))

from tokenizer import Tokenizer
from model.generator import Generator
from model.encoder import Encoder
from model.decoder import Decoder
from model.attn import BahdanauAttention

app = FastAPI()

# Artifact paths resolved relative to this file, so the service does not
# depend on the working directory it was started from.
BASE_DIR = pathlib.Path(__file__).parent
TOKENIZER_PATH = BASE_DIR / "tokenizer.json"
CHECKPOINT_PATH = BASE_DIR / "best_model.pth"

# Model hyperparameters — presumably these must match the values used at
# training time for the checkpoint to load; confirm against the trainer.
VOCAB_SIZE = 8000
EMBED_SIZE = 128
HIDDEN_SIZE = 256
NUM_LAYERS = 3
DROPOUT = 0.2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Populated by load_resources(); None until it has been called.
tokenizer = None
model = None
SOS_IDX = None
EOS_IDX = None
PAD_IDX = None
class GenerationRequest(BaseModel):
    """Request body for pragma generation; mirrors greedy_generate()'s args."""

    # Source code to generate a pragma for.
    code_snippet: str
    # Class-control tag; prepended as "[CLS:<cls>]" unless the snippet
    # already starts with a "[CLS:" prefix.
    cls: str = "parallel"  # default
    # Maximum number of tokens to decode before giving up.
    max_len: int = 100
def load_resources():
    """Load the tokenizer and seq2seq model into the module globals.

    Builds the Encoder/Attention/Decoder/Generator stack sized to the
    tokenizer's *actual* vocabulary and restores weights from
    CHECKPOINT_PATH when the file exists; otherwise the model is left
    randomly initialized (with a warning).

    Raises:
        FileNotFoundError: if the tokenizer file is missing.
    """
    global tokenizer, model, SOS_IDX, EOS_IDX, PAD_IDX

    if not TOKENIZER_PATH.exists():
        raise FileNotFoundError(f"Tokenizer not found at {TOKENIZER_PATH}")

    # Consistency fix: use the shared VOCAB_SIZE constant instead of a
    # hard-coded 8000, so the tokenizer tracks the module configuration.
    tokenizer = Tokenizer(vocab_size=VOCAB_SIZE)
    tokenizer.load(str(TOKENIZER_PATH))
    SOS_IDX = tokenizer.char2idx['<SOS>']
    EOS_IDX = tokenizer.char2idx['<EOS>']
    PAD_IDX = tokenizer.char2idx['<PAD>']

    # Size embedding/output layers to the vocabulary actually loaded,
    # which may differ from the requested VOCAB_SIZE.
    actual_vocab_size = tokenizer.vocab_size
    encoder = Encoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT)
    attention = BahdanauAttention(HIDDEN_SIZE)
    decoder = Decoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, attention, NUM_LAYERS, DROPOUT)
    model = Generator(encoder, decoder, device).to(device)

    if not CHECKPOINT_PATH.exists():
        print("WARNING: Checkpoint not found. Model will be random!")
        # Fix: previously this path left the model in train mode; switch to
        # eval so dropout is disabled during inference either way.
        model.eval()
        return

    # NOTE(review): torch.load unpickles arbitrary objects — only ever load
    # checkpoints from trusted sources.
    checkpoint = torch.load(str(CHECKPOINT_PATH), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
def greedy_generate(code_snippet: str, cls: str, max_len: int) -> str:
    """Greedily decode a pragma for *code_snippet* with the loaded model.

    Args:
        code_snippet: input source text; a "[CLS:<cls>]" control prefix is
            prepended unless the snippet already starts with "[CLS:".
        cls: class tag used when no control prefix is present.
        max_len: maximum number of tokens to decode.

    Returns:
        The decoded token sequence as a string (EOS token excluded).

    Raises:
        HTTPException: 503 if load_resources() has not populated the globals.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    model.eval()

    # Prepend the class-control token unless the caller supplied one already.
    text = code_snippet if code_snippet.startswith("[CLS:") else f"[CLS:{cls}] {code_snippet}"
    input_ids = tokenizer.encode(text, max_length=1500, add_special_tokens=True)
    # Effective length = index of the first PAD token (whole list if none);
    # assumes tokenizer.encode pads to max_length — TODO confirm.
    input_len = next((i for i, tok in enumerate(input_ids) if tok == PAD_IDX), len(input_ids))
    input_tensor = torch.tensor([input_ids], device=device)  # batch of one
    input_len_tensor = torch.tensor([input_len], device=device)

    with torch.no_grad():
        enc_outs, hidden, cell = model.encoder(input_tensor, input_len_tensor)
        # Attention mask: 1.0 over real tokens, 0.0 over padding positions.
        mask = (torch.arange(enc_outs.size(1), device=device).unsqueeze(0) < input_len_tensor.unsqueeze(1)).float()
        # Merge the bidirectional encoder states: split the first axis into
        # (layers, directions), concatenate the two directions along the
        # feature axis, then project down for the (unidirectional) decoder.
        # Assumes the encoder LSTM is bidirectional — the hard-coded 2 and
        # the hidden/cell projections imply it; confirm in model.encoder.
        hidden = hidden.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)
        hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)
        hidden = model.hidden_projection(hidden)
        cell = cell.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)
        cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)
        cell = model.cell_projection(cell)

        # Start decoding from <SOS>; feed each greedy pick back in.
        input_token = torch.tensor([SOS_IDX], device=device)
        generated = []
        for _ in range(max_len):
            output, hidden, cell, _ = model.decoder(input_token, hidden, cell, enc_outs, mask)
            top1 = output.argmax(1)  # greedy: highest-scoring token
            token_id = top1.item()
            if token_id == EOS_IDX:
                break
            generated.append(token_id)
            input_token = top1

    return tokenizer.decode(generated)
def generate_code_snippet(request: GenerationRequest):
    """Handle a pragma-generation request.

    Returns {"pragma": <generated text>}, or an empty pragma for blank input.

    Raises:
        HTTPException: propagated unchanged from greedy_generate (e.g. 503
            when the model is not loaded), or 500 for unexpected errors.

    NOTE(review): no @app.post(...) decorator is visible in this chunk —
    confirm the route is registered elsewhere (app.add_api_route or a lost
    decorator line).
    """
    try:
        # Strip once and reuse (the original stripped twice).
        cleaned_code = request.code_snippet.strip()
        if not cleaned_code:
            return {"pragma": ""}
        result = greedy_generate(cleaned_code, request.cls, request.max_len)
        return {"pragma": result}
    except HTTPException:
        # Bug fix: previously the broad handler below re-wrapped deliberate
        # HTTP errors (e.g. the 503 from greedy_generate) as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))