mohamedahraf273 committed on
Commit
5202b5c
·
1 Parent(s): 3e4a1d2
Files changed (4) hide show
  1. Dockerfile +16 -0
  2. app.py +114 -0
  3. generator.ipynb +17 -2
  4. requirements.txt +12 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.9

# Create and switch to a non-root user (Hugging Face Spaces convention),
# and put its pip-installed scripts (e.g. uvicorn) on PATH.
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Install dependencies first so Docker layer caching survives source changes.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# Copy the application code and serve it on the port Spaces expects (7860).
COPY --chown=user . /app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import pathlib
3
+ import os
4
+ import torch
5
+ import re
6
+ from fastapi import FastAPI, HTTPException
7
+ from pydantic import BaseModel
8
+
9
+ sys.path.append(str(pathlib.Path(__file__).parent.resolve()))
10
+
11
+ from tokenizer import Tokenizer
12
+ from model.generator import Generator
13
+ from model.encoder import Encoder
14
+ from model.decoder import Decoder
15
+ from model.attn import BahdanauAttention
16
+
17
app = FastAPI()

# --- Paths and model hyper-parameters -------------------------------------
BASE_DIR = pathlib.Path(__file__).parent
TOKENIZER_PATH = BASE_DIR / "tokenizer.json"
CHECKPOINT_PATH = BASE_DIR / "best_model.pth"
VOCAB_SIZE = 8000    # requested tokenizer vocabulary size
EMBED_SIZE = 128     # token embedding dimension
HIDDEN_SIZE = 256    # encoder/decoder hidden-state size
NUM_LAYERS = 3       # encoder/decoder layer count
DROPOUT = 0.2

# Prefer GPU when available; everything below is placed on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Populated by load_resources() at application startup; None until then.
tokenizer = None
model = None
SOS_IDX = None
EOS_IDX = None
PAD_IDX = None
34
+
35
class GenerationRequest(BaseModel):
    """Request body for POST /generate."""

    code_snippet: str            # source code to generate a pragma for
    cls: str = "parallel"  # default pragma class tag, prepended as [CLS:...]
    max_len: int = 100           # maximum number of decoding steps
39
+
40
+
41
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases;
# consider migrating to a lifespan handler when upgrading.
@app.on_event("startup")
def load_resources():
    """Load the tokenizer and model weights once at application startup.

    Populates the module-level ``tokenizer``, ``model`` and the special
    token indices (``SOS_IDX``, ``EOS_IDX``, ``PAD_IDX``).

    Raises:
        FileNotFoundError: if the tokenizer file is missing. A missing
            checkpoint only prints a warning and leaves the model with
            random weights.
    """
    global tokenizer, model, SOS_IDX, EOS_IDX, PAD_IDX

    if not TOKENIZER_PATH.exists():
        raise FileNotFoundError(f"Tokenizer not found at {TOKENIZER_PATH}")

    # Fix: use the VOCAB_SIZE constant instead of a hard-coded 8000 so the
    # value stays in sync with the module-level configuration.
    tokenizer = Tokenizer(vocab_size=VOCAB_SIZE)
    tokenizer.load(str(TOKENIZER_PATH))
    SOS_IDX = tokenizer.char2idx['<SOS>']
    EOS_IDX = tokenizer.char2idx['<EOS>']
    PAD_IDX = tokenizer.char2idx['<PAD>']

    # The trained vocabulary may differ from the requested size, so size the
    # model's embedding/output layers from the tokenizer's actual vocabulary.
    actual_vocab_size = tokenizer.vocab_size
    encoder = Encoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT)
    attention = BahdanauAttention(HIDDEN_SIZE)
    decoder = Decoder(actual_vocab_size, EMBED_SIZE, HIDDEN_SIZE, attention, NUM_LAYERS, DROPOUT)
    model = Generator(encoder, decoder, device).to(device)

    if not CHECKPOINT_PATH.exists():
        # Best-effort startup: serve with random weights rather than crash.
        print("WARNING: Checkpoint not found. Model will be random!")
        return

    checkpoint = torch.load(str(CHECKPOINT_PATH), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
65
+
66
def greedy_generate(code_snippet: str, cls: str, max_len: int) -> str:
    """Greedily decode a pragma string for *code_snippet*.

    The snippet is prefixed with a ``[CLS:<cls>]`` tag (unless the caller
    already supplied one), encoded, run through the seq2seq model, and
    decoded token by token until EOS or *max_len* steps.

    Raises:
        HTTPException: 503 when the model/tokenizer have not been loaded.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    model.eval()

    # Prepend the class tag only when the caller has not done so already.
    prompt = code_snippet if code_snippet.startswith("[CLS:") else f"[CLS:{cls}] {code_snippet}"
    token_ids = tokenizer.encode(prompt, max_length=1500, add_special_tokens=True)

    # Effective (unpadded) length: index of the first PAD, or the full length.
    src_len = len(token_ids)
    for pos, tok in enumerate(token_ids):
        if tok == PAD_IDX:
            src_len = pos
            break

    src = torch.tensor([token_ids], device=device)
    src_lens = torch.tensor([src_len], device=device)

    with torch.no_grad():
        enc_outs, hidden, cell = model.encoder(src, src_lens)
        # Attention mask: 1.0 over real tokens, 0.0 over padding positions.
        mask = (torch.arange(enc_outs.size(1), device=device).unsqueeze(0) < src_lens.unsqueeze(1)).float()

        # Fold the bidirectional encoder states into the decoder's layout:
        # concatenate the two directions, then project to the decoder size.
        hidden = hidden.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)
        hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)
        hidden = model.hidden_projection(hidden)
        cell = cell.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)
        cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)
        cell = model.cell_projection(cell)

        step_input = torch.tensor([SOS_IDX], device=device)
        out_tokens = []

        for _ in range(max_len):
            logits, hidden, cell, _ = model.decoder(step_input, hidden, cell, enc_outs, mask)
            best = logits.argmax(1)
            tok_id = best.item()

            if tok_id == EOS_IDX:
                break

            out_tokens.append(tok_id)
            step_input = best

    return tokenizer.decode(out_tokens)
101
+
102
+
103
@app.post("/generate")
def generate_code_snippet(request: GenerationRequest):
    """Generate an OpenMP pragma for the submitted code snippet.

    Returns ``{"pragma": "..."}``; an empty/whitespace-only snippet yields
    an empty pragma without invoking the model.

    Raises:
        HTTPException: 503 if the model is not loaded (propagated from
            greedy_generate), 500 for any unexpected failure.
    """
    try:
        if not request.code_snippet.strip():
            return {"pragma": ""}

        cleaned_code = request.code_snippet.strip()
        result = greedy_generate(cleaned_code, request.cls, request.max_len)
        return {"pragma": result}
    except HTTPException:
        # Bug fix: re-raise deliberate HTTP errors (e.g. the 503 raised by
        # greedy_generate) instead of re-wrapping them as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
114
+
generator.ipynb CHANGED
@@ -309,10 +309,25 @@
309
  },
310
  {
311
  "cell_type": "code",
312
- "execution_count": null,
313
  "id": "6d9a8e25",
314
  "metadata": {},
315
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  "source": [
317
  "\n",
318
  "import sys\n",
 
309
  },
310
  {
311
  "cell_type": "code",
312
+ "execution_count": 18,
313
  "id": "6d9a8e25",
314
  "metadata": {},
315
+ "outputs": [
316
+ {
317
+ "name": "stdout",
318
+ "output_type": "stream",
319
+ "text": [
320
+ "Loaded checkpoint from best_model.pth (epoch 8)\n",
321
+ "Sample input (truncated): [CLS:reduction] for (i = 0; i < 1000; ++i)\n",
322
+ "{\n",
323
+ " logic_and = logic_and && logics[i];\n",
324
+ "}\n",
325
+ "\n",
326
+ "Reference pragma: omp parallel for schedule(dynamic,1) private(i) reduction(&&:logic_and)\n",
327
+ "Greedy prediction: omp parallel for schedule(dynamic,1) private(i) reduction(&&:logic_and)\n"
328
+ ]
329
+ }
330
+ ],
331
  "source": [
332
  "\n",
333
  "import sys\n",
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core Python Utilities
2
+ setuptools
3
+ regex
4
+ packaging
5
+ build
6
+ dm-tree
7
+ scikit-learn
8
+ pandas
9
+ numpy
10
+ torch
11
+ fastapi
12
+ uvicorn[standard]