PhunvVi commited on
Commit
a58a3a1
·
verified ·
1 Parent(s): caca526

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -55
main.py DELETED
@@ -1,55 +0,0 @@
1
- from model import load_model, load_tokenizer
2
- from utils import clean_output
3
- import torch
4
- import shap
5
- from huggingface_hub import login
6
-
7
def generate_questions(context, num_questions=3, max_length=64):
    """Generate questions for a passage of text.

    Wraps the context in a ``generate question:`` prompt, samples
    ``num_questions`` candidate questions from the seq2seq model, and
    returns them after post-processing with ``clean_output``.

    Args:
        context: Source passage the questions should be about.
        num_questions: Number of sampled questions to return.
        max_length: Maximum token length of each generated question.

    Returns:
        The cleaned, decoded question strings (as produced by
        ``clean_output``).
    """
    tok = load_tokenizer()
    mdl = load_model()

    # Prompt prefix matches the format the model was trained on.
    prompt = f"generate question: {context.strip()}"
    encoded = tok(prompt, return_tensors="pt", truncation=True, padding="longest").to(mdl.device)

    # Nucleus sampling so repeated calls yield varied questions.
    generated = mdl.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
        num_return_sequences=num_questions,
        do_sample=True,
        top_p=0.95,
        temperature=1.0,
    )

    decoded_questions = tok.batch_decode(generated, skip_special_tokens=True)
    return clean_output(decoded_questions)
26
-
27
def get_shap_values(tokenizer, model, prompt):
    """Compute SHAP token attributions for a given prompt.

    Args:
        tokenizer: HuggingFace tokenizer (also used as the SHAP masker).
        model: Seq2seq model with a ``generate`` method.
        prompt: Input text to attribute.

    Returns:
        Tuple ``(values, tokens)`` where ``values`` is the SHAP attribution
        array for the prompt and ``tokens`` are the prompt's tokens for
        visualization.
    """
    # Wrapper prediction function: maps a batch of texts to one scalar each,
    # which SHAP perturbs and explains.
    def f(texts):
        inputs = tokenizer(list(texts), return_tensors="pt", truncation=True, padding=True).to(model.device)
        with torch.no_grad():
            out = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=64,
                do_sample=False,
                num_beams=2,
            )
        # FIX: the original returned out[:, 0], but for encoder-decoder
        # models generate() prepends the decoder start token (pad for
        # T5-style models) at position 0, so that column is identical for
        # every input and SHAP sees a constant function (all-zero
        # attributions). Return the first *generated* token instead,
        # falling back to column 0 only if the output is a single token.
        col = 1 if out.shape[1] > 1 else 0
        return out[:, col].detach().cpu().numpy()

    explainer = shap.Explainer(f, tokenizer)
    shap_values = explainer([prompt])

    # Tokens of the prompt itself, for aligning attributions in plots.
    tokens = tokenizer.convert_ids_to_tokens(tokenizer(prompt, return_tensors="pt")["input_ids"][0])
    return shap_values.values[0], tokens
53
-
54
-
55
-