Alleinzellgaenger committed on
Commit
1f4a0b6
·
1 Parent(s): 2a0587f

First draft for attention mechanism

Browse files
Files changed (3) hide show
  1. backend/app.py +3 -1
  2. frontend/index.html +3 -0
  3. frontend/script.js +61 -1
backend/app.py CHANGED
@@ -38,13 +38,15 @@ except Exception as e:
38
  raise
39
 
40
  @app.post("/process")
41
- async def process_text(text: str = Body(..., embed=True)):
42
  """
43
  Process the input text:
44
  - Tokenizes the text
45
  - Runs the GPT2 model to obtain attentions
46
  - Returns the tokens and attention values (rounded to 2 decimals)
47
  """
 
 
48
  try:
49
  logger.info(f"Received text: {text}")
50
  # Tokenize input text (truncating if needed)
 
38
  raise
39
 
40
  @app.post("/process")
41
+ async def process_text(payload: TextRequest):
42
  """
43
  Process the input text:
44
  - Tokenizes the text
45
  - Runs the GPT2 model to obtain attentions
46
  - Returns the tokens and attention values (rounded to 2 decimals)
47
  """
48
+
49
+ text = payload.txt
50
  try:
51
  logger.info(f"Received text: {text}")
52
  # Tokenize input text (truncating if needed)
frontend/index.html CHANGED
@@ -13,6 +13,9 @@
13
  </form>
14
  <div id="output">
15
  <!-- Processed output will be displayed here -->
 
 
 
16
  </div>
17
  <script src="/static/script.js"></script>
18
  </body>
 
13
  </form>
14
  <div id="output">
15
  <!-- Processed output will be displayed here -->
16
+ </div>
17
+ <div id="tokenContainer">
18
+
19
  </div>
20
  <script src="/static/script.js"></script>
21
  </body>
frontend/script.js CHANGED
@@ -24,12 +24,72 @@ document.getElementById('textForm').addEventListener('submit', async (e) => {
24
 
25
  // Function to display the tokens and attention values
26
  function displayOutput(data) {
 
27
  const outputDiv = document.getElementById('output');
28
  outputDiv.innerHTML = `
29
  <h2>Tokens</h2>
30
  <pre>${JSON.stringify(data.tokens, null, 2)}</pre>
31
  <h2>Attention</h2>
32
  <pre>${JSON.stringify(data.attention, null, 2)}</pre>
 
33
  `;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  }
35
-
 
24
 
25
  // Function to display the tokens and attention values
26
  function displayOutput(data) {
27
+ // Optionally still display the raw JSON if needed
28
  const outputDiv = document.getElementById('output');
29
  outputDiv.innerHTML = `
30
  <h2>Tokens</h2>
31
  <pre>${JSON.stringify(data.tokens, null, 2)}</pre>
32
  <h2>Attention</h2>
33
  <pre>${JSON.stringify(data.attention, null, 2)}</pre>
34
+ <div id="tokenContainer"></div>
35
  `;
36
+
37
+ // Render tokens using the first layer's attention (or adjust as needed)
38
+ if (data.attention && data.attention.length > 0) {
39
+ // For instance, using the first layer. You might refine this to select a specific head.
40
+ renderTokens(data.tokens, data.attention[0]);
41
+ }
42
+ }
43
+
44
/**
 * Render each token as a hoverable <span> inside #tokenContainer.
 *
 * Hovering a token calls `highlightAttention` for that token's index;
 * leaving it calls `resetTokenSizes`.
 *
 * @param {string[]} tokens - Tokens to display, in order.
 * @param {Array} attentionMatrix - Passed through unchanged to
 *   `highlightAttention`, which indexes it by token position.
 */
function renderTokens(tokens, attentionMatrix) {
  const container = document.getElementById('tokenContainer');
  if (!container) {
    // Nothing to render into; avoid a TypeError on pages without the element.
    return;
  }

  container.innerHTML = ""; // remove previous tokens

  tokens.forEach((token, index) => {
    const span = document.createElement("span");
    span.textContent = token + " ";
    // A CSS <time> needs a unit; "0.2" alone makes the whole declaration
    // invalid and the transition is silently dropped.
    span.style.transition = "font-size 0.2s ease";
    span.dataset.tokenIndex = index;

    span.addEventListener('mouseover', () => {
      // Pass the full token list, matching highlightAttention's
      // (hoveredIndex, tokens, attentionMatrix) signature.
      highlightAttention(index, tokens, attentionMatrix);
    });

    span.addEventListener('mouseout', () => {
      resetTokenSizes();
    });

    container.appendChild(span);
  });
}
66
+
67
/**
 * Resize every rendered token according to the attention row of the hovered
 * token: the largest weight gets base + maxIncrease pixels, a zero weight
 * stays at the base size.
 *
 * @param {number} hoveredIndex - Index of the hovered token; selects the row
 *   `attentionMatrix[hoveredIndex]`.
 * @param {string[]} tokens - Unused; kept so existing callers keep working.
 * @param {Array} attentionMatrix - Indexed by token position.
 *   Assumes each row is a flat array of numbers, one per rendered token —
 *   TODO confirm against the backend response shape.
 */
function highlightAttention(hoveredIndex, tokens, attentionMatrix) {
  const container = document.getElementById("tokenContainer");
  // Get the attention weights for the hovered token
  const weights = attentionMatrix ? attentionMatrix[hoveredIndex] : undefined;
  if (!container || !Array.isArray(weights) || weights.length === 0) {
    // Missing container or row: leave the current sizes untouched.
    return;
  }

  // Define a base font size and a maximum increase
  const baseFontSize = 16; // in pixels
  const maxIncrease = 10; // additional pixels for the maximum attention value

  // Largest finite weight in the row. reduce() avoids the argument-count
  // limit of Math.max(...weights) and ignores non-numeric entries.
  const maxWeight = weights.reduce(
    (best, w) => (Number.isFinite(w) && w > best ? w : best),
    0
  );

  // Iterate over tokens and adjust font sizes
  Array.from(container.children).forEach((span, idx) => {
    const weight = weights[idx];
    // An all-zero row or a missing/non-numeric weight would produce NaN;
    // fall back to the base size in those cases.
    const ratio =
      maxWeight > 0 && Number.isFinite(weight) ? weight / maxWeight : 0;
    span.style.fontSize = baseFontSize + ratio * maxIncrease + "px";
  });
}
88
+
89
/**
 * Restore every token <span> inside #tokenContainer to the 16px base size.
 */
function resetTokenSizes() {
  const tokenSpans = document.getElementById("tokenContainer").children;
  for (const tokenSpan of tokenSpans) {
    tokenSpan.style.fontSize = "16px"; // base size
  }
}
95
+