File size: 4,044 Bytes
5202b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import sys
import pathlib
import os
import torch
import re
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

sys.path.append(str(pathlib.Path(__file__).parent.resolve()))

from tokenizer import Tokenizer
from model.generator import Generator
from model.encoder import Encoder
from model.decoder import Decoder
from model.attn import BahdanauAttention

app = FastAPI()

# Paths and model hyperparameters.
# NOTE(review): EMBED_SIZE/HIDDEN_SIZE/NUM_LAYERS/DROPOUT must match the
# values used when the checkpoint was trained -- confirm against the
# training configuration.
BASE_DIR = pathlib.Path(__file__).parent
TOKENIZER_PATH = BASE_DIR / "tokenizer.json"
CHECKPOINT_PATH = BASE_DIR / "best_model.pth"
VOCAB_SIZE = 8000   # maximum tokenizer vocabulary size
EMBED_SIZE = 128    # token embedding dimension
HIDDEN_SIZE = 256   # recurrent hidden-state size
NUM_LAYERS = 3      # encoder/decoder layer count
DROPOUT = 0.2       # dropout probability (training only)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Populated by load_resources() at application startup; None until then.
tokenizer = None
model = None
SOS_IDX = None
EOS_IDX = None
PAD_IDX = None

class GenerationRequest(BaseModel):
    """Request body for POST /generate."""
    code_snippet: str  # source code to generate a pragma for
    cls: str = "parallel" # pragma class tag; default "parallel"
    max_len: int = 100  # maximum number of tokens to decode


@app.on_event("startup")
def load_resources():
    """Load the tokenizer and seq2seq model once at server startup.

    Populates the module-level globals consumed by `greedy_generate`.
    If the checkpoint file is missing, the model keeps its random
    initialization (a warning is printed) but is still switched to eval
    mode so dropout is disabled at inference time.

    Raises:
        FileNotFoundError: if the tokenizer file is missing.
    """
    global tokenizer, model, SOS_IDX, EOS_IDX, PAD_IDX

    if not TOKENIZER_PATH.exists():
        raise FileNotFoundError(f"Tokenizer not found at {TOKENIZER_PATH}")

    # Use the module-level constant instead of a duplicated literal.
    tokenizer = Tokenizer(vocab_size=VOCAB_SIZE)
    tokenizer.load(str(TOKENIZER_PATH))
    SOS_IDX = tokenizer.char2idx['<SOS>']
    EOS_IDX = tokenizer.char2idx['<EOS>']
    PAD_IDX = tokenizer.char2idx['<PAD>']

    # The loaded tokenizer's effective vocabulary may differ from the
    # configured maximum, so size the embeddings from the tokenizer.
    actual_vocab_size = tokenizer.vocab_size
    encoder = Encoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT)
    attention = BahdanauAttention(HIDDEN_SIZE)
    decoder = Decoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, attention, NUM_LAYERS, DROPOUT)
    model = Generator(encoder, decoder, device).to(device)

    # Always switch to eval mode. Previously this was skipped when the
    # checkpoint was missing, leaving dropout active during inference.
    model.eval()

    if not CHECKPOINT_PATH.exists():
        print("WARNING: Checkpoint not found. Model will be random!")
        return

    checkpoint = torch.load(str(CHECKPOINT_PATH), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

def greedy_generate(code_snippet: str, cls: str, max_len: int) -> str:
    """Greedily decode a pragma for *code_snippet*, one token at a time.

    The snippet is prefixed with a ``[CLS:<cls>]`` control tag (unless the
    caller already supplied one), encoded, run through the encoder, and
    decoded by repeatedly taking the argmax token until <EOS> is produced
    or *max_len* tokens have been emitted.

    Raises:
        HTTPException: 503 if the model/tokenizer have not been loaded.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    model.eval()

    # Prepend the class-control tag only when the caller did not.
    if code_snippet.startswith("[CLS:"):
        tagged = code_snippet
    else:
        tagged = f"[CLS:{cls}] {code_snippet}"

    token_ids = tokenizer.encode(tagged, max_length=1500, add_special_tokens=True)

    # Effective (unpadded) length: index of the first PAD, or full length.
    true_len = len(token_ids)
    for pos, tok in enumerate(token_ids):
        if tok == PAD_IDX:
            true_len = pos
            break

    src = torch.tensor([token_ids], device=device)
    src_len = torch.tensor([true_len], device=device)

    with torch.no_grad():
        enc_outs, hidden, cell = model.encoder(src, src_len)

        # Attention mask: 1.0 over real tokens, 0.0 over padding.
        positions = torch.arange(enc_outs.size(1), device=device).unsqueeze(0)
        mask = (positions < src_len.unsqueeze(1)).float()

        # Fold the bidirectional encoder states down to the decoder's
        # hidden size via the model's learned projections.
        def _merge(state, projection):
            state = state.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)
            state = torch.cat((state[:, 0], state[:, 1]), dim=2)
            return projection(state)

        hidden = _merge(hidden, model.hidden_projection)
        cell = _merge(cell, model.cell_projection)

        step_input = torch.tensor([SOS_IDX], device=device)
        out_tokens = []

        for _ in range(max_len):
            logits, hidden, cell, _ = model.decoder(step_input, hidden, cell, enc_outs, mask)
            best = logits.argmax(1)
            best_id = best.item()

            if best_id == EOS_IDX:
                break

            out_tokens.append(best_id)
            step_input = best

    return tokenizer.decode(out_tokens)


@app.post("/generate")
def generate_code_snippet(request: GenerationRequest):
    """Generate a pragma for the submitted code snippet.

    Returns ``{"pragma": <generated text>}``; an empty/whitespace-only
    snippet short-circuits to an empty pragma without touching the model.

    Raises:
        HTTPException: 503 if the model is not loaded (propagated from
            `greedy_generate`), 500 for any unexpected failure.
    """
    try:
        if not request.code_snippet.strip():
             return {"pragma": ""}

        cleaned_code = request.code_snippet.strip()
        result = greedy_generate(cleaned_code, request.cls, request.max_len)
        return {"pragma": result}
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 503 raised while the
        # model is still loading) instead of collapsing them into a
        # generic 500 via the broad handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))