Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -32,10 +32,11 @@ def generate_text(input_text, temperature, top_p, top_k):
|
|
| 32 |
global model, tokenizer
|
| 33 |
|
| 34 |
inputs = tokenizer(input_text, return_tensors="pt")
|
|
|
|
| 35 |
|
| 36 |
with torch.no_grad():
|
| 37 |
outputs = model.generate(
|
| 38 |
-
|
| 39 |
max_new_tokens=50,
|
| 40 |
temperature=temperature,
|
| 41 |
top_p=top_p,
|
|
@@ -46,9 +47,11 @@ def generate_text(input_text, temperature, top_p, top_k):
|
|
| 46 |
|
| 47 |
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
|
| 48 |
|
| 49 |
-
#
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# Visualiser l'attention
|
| 54 |
plt.figure(figsize=(10, 10))
|
|
@@ -58,12 +61,13 @@ def generate_text(input_text, temperature, top_p, top_k):
|
|
| 58 |
plt.close()
|
| 59 |
|
| 60 |
# Obtenir les mots les plus probables
|
| 61 |
-
probs = torch.nn.functional.softmax(
|
| 62 |
-
top_probs, top_indices = torch.topk(probs, k=5)
|
| 63 |
-
top_words = [tokenizer.decode([idx]) for idx in top_indices]
|
| 64 |
|
| 65 |
return generated_text, attention_plot, top_words
|
| 66 |
|
|
|
|
| 67 |
def reset():
|
| 68 |
return "", 1.0, 1.0, 50, None, None, None
|
| 69 |
|
|
|
|
| 32 |
global model, tokenizer
|
| 33 |
|
| 34 |
inputs = tokenizer(input_text, return_tensors="pt")
|
| 35 |
+
input_ids = inputs["input_ids"]
|
| 36 |
|
| 37 |
with torch.no_grad():
|
| 38 |
outputs = model.generate(
|
| 39 |
+
input_ids,
|
| 40 |
max_new_tokens=50,
|
| 41 |
temperature=temperature,
|
| 42 |
top_p=top_p,
|
|
|
|
| 47 |
|
| 48 |
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
|
| 49 |
|
| 50 |
+
# Obtenir les logits pour le dernier token généré
|
| 51 |
+
last_token_logits = model(outputs.sequences[:, -1:]).logits[:, -1, :]
|
| 52 |
+
|
| 53 |
+
# Extraire les attentions
|
| 54 |
+
attentions = outputs.attentions[-1][-1].mean(dim=0).numpy()
|
| 55 |
|
| 56 |
# Visualiser l'attention
|
| 57 |
plt.figure(figsize=(10, 10))
|
|
|
|
| 61 |
plt.close()
|
| 62 |
|
| 63 |
# Obtenir les mots les plus probables
|
| 64 |
+
probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
|
| 65 |
+
top_probs, top_indices = torch.topk(probs[0], k=5)
|
| 66 |
+
top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
|
| 67 |
|
| 68 |
return generated_text, attention_plot, top_words
|
| 69 |
|
| 70 |
+
|
| 71 |
def reset():
|
| 72 |
return "", 1.0, 1.0, 50, None, None, None
|
| 73 |
|