msaifee commited on
Commit
bb2d5a0
·
verified ·
1 Parent(s): c292083

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -15,11 +15,18 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
15
  model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token)
16
 
17
  # Define the inference function
18
- def generate_text(prompt, max_length=100, temperature=0.7):
19
- inputs = tokenizer(prompt, return_tensors="pt")
20
- output = model.generate(inputs['input_ids'], max_length=max_length, temperature=temperature)
 
 
 
 
 
 
21
  return tokenizer.decode(output[0], skip_special_tokens=True)
22
 
 
23
  # Create the Gradio interface
24
  iface = gr.Interface(
25
  fn=generate_text,
 
15
  model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token)
16
 
17
  # Define the inference function
18
def generate_text(prompt, max_length, temperature):
    """Generate a text continuation for *prompt* with the module-level model.

    Args:
        prompt: Input text to continue.
        max_length: Total token budget for the generated sequence
            (prompt tokens included, per HF ``generate`` semantics).
        temperature: Sampling temperature; higher values increase randomness.

    Returns:
        The decoded generated sequence as a plain string, with special
        tokens stripped.
    """
    # Tokenize with padding/truncation so over-long prompts are clipped
    # and an attention mask is produced alongside the input ids.
    encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

    # Sampling must be enabled explicitly, otherwise `temperature` is ignored
    # by greedy decoding; the attention mask keeps padded positions masked out.
    generated = model.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
        temperature=temperature,
        do_sample=True,
    )

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
28
 
29
+
30
  # Create the Gradio interface
31
  iface = gr.Interface(
32
  fn=generate_text,