Spaces:
Running
Running
Commit ·
1f4a0b6
1
Parent(s): 2a0587f
First draft for attention mechanism
Browse files- backend/app.py +3 -1
- frontend/index.html +3 -0
- frontend/script.js +61 -1
backend/app.py
CHANGED
|
@@ -38,13 +38,15 @@ except Exception as e:
|
|
| 38 |
raise
|
| 39 |
|
| 40 |
@app.post("/process")
|
| 41 |
-
async def process_text(
|
| 42 |
"""
|
| 43 |
Process the input text:
|
| 44 |
- Tokenizes the text
|
| 45 |
- Runs the GPT2 model to obtain attentions
|
| 46 |
- Returns the tokens and attention values (rounded to 2 decimals)
|
| 47 |
"""
|
|
|
|
|
|
|
| 48 |
try:
|
| 49 |
logger.info(f"Received text: {text}")
|
| 50 |
# Tokenize input text (truncating if needed)
|
|
|
|
| 38 |
raise
|
| 39 |
|
| 40 |
@app.post("/process")
|
| 41 |
+
async def process_text(payload: TextRequest):
|
| 42 |
"""
|
| 43 |
Process the input text:
|
| 44 |
- Tokenizes the text
|
| 45 |
- Runs the GPT2 model to obtain attentions
|
| 46 |
- Returns the tokens and attention values (rounded to 2 decimals)
|
| 47 |
"""
|
| 48 |
+
|
| 49 |
+
text = payload.txt
|
| 50 |
try:
|
| 51 |
logger.info(f"Received text: {text}")
|
| 52 |
# Tokenize input text (truncating if needed)
|
frontend/index.html
CHANGED
|
@@ -13,6 +13,9 @@
|
|
| 13 |
</form>
|
| 14 |
<div id="output">
|
| 15 |
<!-- Processed output will be displayed here -->
|
|
|
|
|
|
|
|
|
|
| 16 |
</div>
|
| 17 |
<script src="/static/script.js"></script>
|
| 18 |
</body>
|
|
|
|
| 13 |
</form>
|
| 14 |
<div id="output">
|
| 15 |
<!-- Processed output will be displayed here -->
|
| 16 |
+
</div>
|
| 17 |
+
<div id="tokenContainer">
|
| 18 |
+
|
| 19 |
</div>
|
| 20 |
<script src="/static/script.js"></script>
|
| 21 |
</body>
|
frontend/script.js
CHANGED
|
@@ -24,12 +24,72 @@ document.getElementById('textForm').addEventListener('submit', async (e) => {
|
|
| 24 |
|
| 25 |
// Function to display the tokens and attention values
|
| 26 |
function displayOutput(data) {
|
|
|
|
| 27 |
const outputDiv = document.getElementById('output');
|
| 28 |
outputDiv.innerHTML = `
|
| 29 |
<h2>Tokens</h2>
|
| 30 |
<pre>${JSON.stringify(data.tokens, null, 2)}</pre>
|
| 31 |
<h2>Attention</h2>
|
| 32 |
<pre>${JSON.stringify(data.attention, null, 2)}</pre>
|
|
|
|
| 33 |
`;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
}
|
| 35 |
-
|
|
|
|
| 24 |
|
| 25 |
// Function to display the tokens and attention values
function displayOutput(data) {
  // Show the raw JSON payload (tokens + attention) for inspection.
  const outputDiv = document.getElementById('output');
  outputDiv.innerHTML = `
    <h2>Tokens</h2>
    <pre>${JSON.stringify(data.tokens, null, 2)}</pre>
    <h2>Attention</h2>
    <pre>${JSON.stringify(data.attention, null, 2)}</pre>
  `;

  // BUGFIX: the original also injected a second <div id="tokenContainer">
  // into #output here, while index.html already declares a static
  // #tokenContainer — duplicate ids make getElementById's result
  // order-dependent. The static container is sufficient, so the
  // duplicate is no longer created.

  // Render interactive tokens using the first entry of the attention
  // payload as the token-to-token matrix.
  // NOTE(review): if the backend returns [layer][head][seq][seq],
  // data.attention[0] is a layer (not a seq x seq matrix) — confirm
  // the shape against backend/app.py.
  if (data.attention && data.attention.length > 0) {
    renderTokens(data.tokens, data.attention[0]);
  }
}
|
| 43 |
+
|
| 44 |
+
// Build one <span> per token inside #tokenContainer and attach hover
// handlers that resize tokens according to attention weights.
// tokens: array of token strings; attentionMatrix: per-token weight rows
// consumed by highlightAttention.
function renderTokens(tokens, attentionMatrix) {
  const container = document.getElementById('tokenContainer');

  container.innerHTML = ''; // remove previous tokens

  tokens.forEach((token, index) => {
    const span = document.createElement('span');
    span.textContent = token + ' ';
    // BUGFIX: the original value "font-size 0.2 ease" has no time unit,
    // which invalidates the whole transition declaration (browsers
    // silently drop it). "0.2s" makes the resize animate as intended.
    span.style.transition = 'font-size 0.2s ease';
    span.dataset.tokenIndex = index;

    // Emphasize attended-to tokens while this one is hovered...
    span.addEventListener('mouseover', () => {
      highlightAttention(index, token, attentionMatrix);
    });

    // ...and restore uniform sizing when the pointer leaves.
    span.addEventListener('mouseout', () => {
      resetTokenSizes();
    });

    container.appendChild(span);
  });
}
|
| 66 |
+
|
| 67 |
+
// Scale every rendered token's font size in proportion to how strongly
// the hovered token attends to it (row hoveredIndex of attentionMatrix).
// NOTE: the second parameter was originally named `tokens`, but the call
// site passes a single token string and the value is never used; renamed
// to reflect that (positional JS parameters, so callers are unaffected).
function highlightAttention(hoveredIndex, hoveredToken, attentionMatrix) {
  const container = document.getElementById("tokenContainer");
  // Attention weights from the hovered token to every token.
  const weights = attentionMatrix[hoveredIndex];

  // BUGFIX: an all-zero row made the original divide by zero, producing
  // NaN font sizes; fall back to 1 so tokens keep the base size instead.
  const maxWeight = Math.max(...weights) || 1;

  // Base size plus a linear bonus, maxed out at the strongest weight.
  const baseFontSize = 16; // in pixels
  const maxIncrease = 10;  // additional pixels for the maximum attention value

  Array.from(container.children).forEach((span, idx) => {
    // Scale linearly relative to the largest weight in this row.
    const newFontSize = baseFontSize + (weights[idx] / maxWeight) * maxIncrease;
    span.style.fontSize = newFontSize + "px";
  });
}
|
| 88 |
+
|
| 89 |
+
// Restore every token in #tokenContainer to the base font size (16px),
// undoing any hover-driven resizing applied by highlightAttention.
function resetTokenSizes() {
  const tokenSpans = document.getElementById("tokenContainer").children;
  for (const span of tokenSpans) {
    span.style.fontSize = "16px"; // Reset to base size
  }
}
|
| 95 |
+
|