# Juan Manuel Hernández Roda
# Add application file
# a8ea8b8
# Models
import torch
import numpy as np
import networkx as nx
from transformers import AutoTokenizer, BertForPreTraining, AutoModelForCausalLM
# API
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
def compute_attention_rollout(attn_mean):
    """Return the cumulative attention rollout after every layer.

    Args:
        attn_mean: 3-D tensor indexed as (layer, row, column); each layer
            slice is a square matrix. Presumably head-averaged attention --
            confirm against the caller.

    Returns:
        Tensor with one rollout matrix per layer, stacked along dim 0. Entry
        ``k`` is the product of the residual-augmented, row-normalised
        matrices of layers ``0..k`` (the newest layer multiplies on the left).
    """
    _, seq_len, _ = attn_mean.shape
    identity = torch.eye(seq_len, device=attn_mean.device)

    running = identity.clone()
    per_layer = []
    for layer_attn in attn_mean:
        # Add the identity to model the residual connection described in the
        # attention-rollout paper.
        with_residual = layer_attn + identity
        # Re-normalise each row; the clamp avoids dividing by zero.
        row_totals = with_residual.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        running = (with_residual / row_totals) @ running
        per_layer.append(running.clone())

    return torch.stack(per_layer, dim=0)
def residual_and_normalize(attention_layers):
    """Add the residual identity to every layer and row-normalise the result.

    Args:
        attention_layers: Array-like indexed as (layer, row, column) with
            square layer slices. Integer and boolean inputs are promoted to
            float64; floating inputs keep their dtype.

    Returns:
        A new ``numpy.ndarray`` of the same shape in which each layer is
        ``A + I`` divided by its row sums. The input is never modified.

    Raises:
        ValueError: If the input is not 3-D or the layer slices are not
            compatible with a ``(seq_len, seq_len)`` identity.
    """
    layers = np.asarray(attention_layers)
    # An in-place float add on an integer array raises a casting error, so
    # promote anything that is not already floating point.
    if not np.issubdtype(layers.dtype, np.floating):
        layers = layers.astype(np.float64)

    _, seq_len, _ = layers.shape
    # Residual connection: broadcasting adds the identity to every layer and
    # allocates a fresh array, leaving the caller's data untouched.
    augmented = layers + np.eye(seq_len, dtype=layers.dtype)
    # Normalisation: each row is divided by its own sum.
    augmented /= augmented.sum(axis=-1, keepdims=True)
    return augmented
def get_node_index(layer_idx, token_position, seq_len):
    """Map a (layer, token position) pair to a flat node id.

    Nodes are numbered layer by layer: every layer occupies ``seq_len``
    consecutive ids and the token position is the offset inside that run.
    """
    layer_offset = seq_len * layer_idx
    return layer_offset + token_position
def build_attention_graph(augmented_attentions):
    """Build the layered flow network used for attention flow.

    Args:
        augmented_attentions: Array indexed as (layer, from_token, to_token).
            Presumably the output of ``residual_and_normalize`` -- confirm
            against the caller.

    Returns:
        ``(graph, super_sink)``: a ``networkx.DiGraph`` holding one node per
        (layer, token) pair for the input layer plus every attention layer,
        and the id of an extra sink node that every input-layer node feeds.
    """
    num_layers, seq_len, _ = augmented_attentions.shape

    # One block of seq_len nodes per attention layer plus the input layer;
    # the super sink takes the next free id.
    node_count = (num_layers + 1) * seq_len
    sink = node_count

    graph = nx.DiGraph()
    graph.add_nodes_from(range(node_count + 1))

    # Edges run from the attending token in layer i to the attended token in
    # layer i - 1, with the attention weight as capacity.
    for upper in range(1, num_layers + 1):
        weights = augmented_attentions[upper - 1]
        for observer in range(seq_len):
            tail = get_node_index(upper, observer, seq_len)
            for observed in range(seq_len):
                head = get_node_index(upper - 1, observed, seq_len)
                cap = float(weights[observer, observed])
                # Non-positive weights contribute no edge at all.
                if cap > 0:
                    graph.add_edge(tail, head, capacity=cap)

    # Connect every input-layer node to the super sink with a fixed capacity
    # of 1e3 (presumably chosen to sit far above any attention-derived
    # capacity -- confirm).
    sink_capacity = 1e3
    for token in range(seq_len):
        graph.add_edge(get_node_index(0, token, seq_len), sink, capacity=sink_capacity)

    return graph, sink
def compute_attention_flow_matrices(layers_mean):
    """Compute one row-normalised attention-flow matrix per layer.

    For every node of every attention layer a maximum flow is solved from
    that node to the super sink; the amount each input-layer node forwards
    to the sink becomes that node's row, normalised to sum to 1.

    Args:
        layers_mean: Array-like indexed as (layer, from_token, to_token).

    Returns:
        List with one ``(T, T)`` float64 ``numpy.ndarray`` per layer, where
        row ``u`` is the normalised flow from token ``u`` of that layer to
        each input token. A row whose total flow is zero is left as zeros
        instead of being divided into NaNs.

    NOTE(review): the per-input split is read from a single max-flow
    solution per source. Such a decomposition is generally not unique, so
    the distribution may depend on the solver -- confirm this is intended.
    """
    attn = np.asarray(layers_mean)  # (L, T, T)
    num_layers, seq_len, _ = attn.shape

    # Add the residual, normalise, then build the layered flow network
    # (edges go from layer i to layer i - 1).
    graph, super_sink = build_attention_graph(residual_and_normalize(attn))

    # Node ids of layer 0, i.e. the input tokens.
    input_nodes = [get_node_index(0, tok, seq_len) for tok in range(seq_len)]
    preflow_push = nx.algorithms.flow.preflow_push

    flow_layers = []
    for layer_idx in range(1, num_layers + 1):
        layer_flow = np.zeros((seq_len, seq_len), dtype=np.float64)
        for tok in range(seq_len):
            source = get_node_index(layer_idx, tok, seq_len)
            _, flow_dict = nx.maximum_flow(
                graph, source, super_sink, flow_func=preflow_push
            )
            row = np.array(
                [
                    float(flow_dict.get(node, {}).get(super_sink, 0))
                    for node in input_nodes
                ],
                dtype=np.float64,
            )
            # Normalise only when some flow reached the sink; dividing a
            # zero row by zero would yield NaNs.
            total = row.sum()
            if total > 0:
                row /= total
            layer_flow[tok, :] = row
        flow_layers.append(layer_flow)

    return flow_layers
def process_prompt(prompt):
    """Run the loaded model on ``prompt`` and collect attention summaries.

    Returns a dict with the model name, the prompt, its tokens and character
    offsets, the head-averaged attention per layer, the attention rollout and
    the attention flow, all as nested Python lists.
    """
    encoded = tokenizer(
        prompt, return_tensors="pt", return_offsets_mapping=True
    ).to(model.device)
    # The offsets are only reported back; they are removed before the
    # remaining entries are passed to the model.
    offsets = encoded.pop("offset_mapping")[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])

    with torch.no_grad():
        outputs = model(**encoded, output_attentions=True, return_dict=True)

    # Stack the per-layer attentions and drop dim 1 (presumably the batch
    # axis of size 1 -- confirm), then average over the new dim 1
    # (presumably heads).
    stacked = torch.stack(outputs.attentions, dim=0).squeeze(1)
    att_mean = stacked.mean(dim=1)

    rollout = compute_attention_rollout(att_mean)
    flow = compute_attention_flow_matrices(att_mean.detach().cpu().numpy())

    def to_nested_list(tensor):
        return tensor.detach().cpu().numpy().tolist()

    result = {}
    result["model"] = model_name
    result["prompt"] = prompt
    result["tokens"] = tokens
    result["offsets"] = offsets
    result["layers_mean"] = [to_nested_list(layer) for layer in att_mean]
    result["attention_rollout"] = [to_nested_list(layer) for layer in rollout]
    result["attention_flow"] = [matrix.tolist() for matrix in flow]
    return result
# Startup diagnostics: report the torch build and whether CUDA is usable.
_cuda_available = torch.cuda.is_available()
print(torch.__version__)
print(_cuda_available)

# Short model key -> Hugging Face model identifier.
_MODEL_IDS = {
    "gpt2": "gpt2",
    "bert": "bert-base-uncased",
    "qwen": "Qwen/Qwen3-1.7B",
}

name = "gpt2"
# Fail early with a clear message instead of leaving model_name unbound,
# which would only surface later as an unrelated NameError.
if name not in _MODEL_IDS:
    raise ValueError(
        f"Unknown model key {name!r}; expected one of {sorted(_MODEL_IDS)}"
    )
model_name = _MODEL_IDS[name]

device = "cuda" if _cuda_available else "cpu"

# "eager" attention is requested explicitly; presumably so attention weights
# can be returned with output_attentions=True -- confirm.
if name == "bert":
    model = BertForPreTraining.from_pretrained(model_name, attn_implementation="eager")
else:
    model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
model.eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# HTTP API: application object and CORS policy.
app = FastAPI(version="1.0", title="Attention Server")

# Every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive policy; confirm it is intended for this deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Request body accepted by the POST /attentions endpoint.
class AttnIn(BaseModel):
    # Required input text; the Ellipsis default marks the field as mandatory.
    prompt: str = Field(default=..., description="Texto de entrada")
@app.get("/health")
def health():
    # Liveness probe: reports a fixed "ok" status together with the loaded
    # model identifier and the device string chosen at startup.
    return dict(status="ok", model=model_name, device=device)
@app.post("/attentions")
def attentions(payload: AttnIn):
    # Delegate to process_prompt and return its dict unchanged as the
    # response body.
    result = process_prompt(payload.prompt)
    return result