yasserrmd committed on
Commit
ad12482
·
verified ·
1 Parent(s): 4e7fa9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -8
app.py CHANGED
@@ -7,9 +7,16 @@ from synthid_text import synthid_mixin, logits_processing
7
  # Configurations and model selection
8
  MODEL_NAME = "google/gemma-7b-it"
9
  DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
 
 
 
10
 
11
  # Initialize model and tokenizer
12
- model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
 
 
 
 
13
  tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
14
  tokenizer.pad_token = tokenizer.eos_token
15
  tokenizer.padding_side = "left"
@@ -24,22 +31,25 @@ CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG
24
  def check_plagiarism(text):
25
  # Logits processor for SynthID
26
  logits_processor = logits_processing.SynthIDLogitsProcessor(
27
- **CONFIG, top_k=40, temperature=0.5
28
  )
29
 
30
  # Tokenize and process the input text
31
- inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
32
  inputs_len = inputs['input_ids'].shape[1]
33
 
34
  # Generate output with model, capturing scores (logits)
35
  with torch.no_grad():
36
  outputs = model.generate(
37
- inputs['input_ids'],
38
- max_length=inputs_len + 50, # Generate up to 50 additional tokens
39
- output_scores=True,
40
- return_dict_in_generate=True
 
 
41
  )
42
-
 
43
  # Extract the generated tokens from the model's predictions
44
  generated_tokens = outputs.sequences[:, inputs_len:]
45
 
 
7
  # Configurations and model selection
8
  MODEL_NAME = "google/gemma-7b-it"
9
  DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
10
+ TOP_K = 40
11
+ TOP_P = 0.99
12
+ TEMPERATURE= 0.5
13
 
14
  # Initialize model and tokenizer
15
+ model = synthid_mixin.SynthIDGemmaForCausalLM.from_pretrained(
16
+ MODEL_NAME,
17
+ device_map=DEVICE,
18
+ torch_dtype=torch.bfloat16,
19
+ )
20
  tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
21
  tokenizer.pad_token = tokenizer.eos_token
22
  tokenizer.padding_side = "left"
 
31
  def check_plagiarism(text):
32
  # Logits processor for SynthID
33
  logits_processor = logits_processing.SynthIDLogitsProcessor(
34
+ **CONFIG, top_k=TOP_K, temperature=TEMPERATURE
35
  )
36
 
37
  # Tokenize and process the input text
38
+ inputs = tokenizer(text, return_tensors="pt", padding=True).to(DEVICE)
39
  inputs_len = inputs['input_ids'].shape[1]
40
 
41
  # Generate output with model, capturing scores (logits)
42
  with torch.no_grad():
43
  outputs = model.generate(
44
+ **inputs,
45
+ do_sample=True,
46
+ max_length=1024,
47
+ temperature=TEMPERATURE,
48
+ top_k=TOP_K,
49
+ top_p=TOP_P,
50
  )
51
+
52
+
53
  # Extract the generated tokens from the model's predictions
54
  generated_tokens = outputs.sequences[:, inputs_len:]
55