mssaidat commited on
Commit
896a6d3
·
verified ·
1 Parent(s): 70263d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -15
app.py CHANGED
@@ -12,25 +12,44 @@ model = AutoModelForCausalLM.from_pretrained(
12
  torch_dtype="auto"
13
  )
14
 
 
 
 
 
 
 
15
  # Create generation pipeline
16
- story_generator = pipeline(
17
- "text-generation",
 
 
 
 
 
 
18
  model=model,
19
- tokenizer=tokenizer
 
20
  )
21
 
22
  # Function to generate stories
23
- def generate_story(prompt, max_length=300, temperature=0.8):
24
- outputs = story_generator(
25
- model,
26
- #prompt,
27
- max_length=max_length,
28
- temperature=temperature,
29
- do_sample=True,
30
- top_p=0.95,
31
- top_k=50
32
- )
33
- return outputs[0]["generated_text"]
 
 
 
 
 
 
34
 
35
  # Gradio UI
36
  with gr.Blocks() as demo:
@@ -42,7 +61,7 @@ with gr.Blocks() as demo:
42
  lines=3
43
  )
44
 
45
- max_length = gr.Slider(50, 1000, value=300, step=50, label="Story Length")
46
  temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Creativity")
47
 
48
  generate_btn = gr.Button("✨ Generate Story")
 
12
  torch_dtype="auto"
13
  )
14
 
15
+ # Ensure pad token is set for safe generation
16
+ if tokenizer.pad_token is None:
17
+ tokenizer.pad_token = tokenizer.eos_token
18
+ if getattr(model.config, "pad_token_id", None) is None:
19
+ model.config.pad_token_id = tokenizer.pad_token_id
20
+
21
  # Create generation pipeline
22
+ #story_generator = pipeline(
23
+ #"text-generation",
24
+ #model=model,
25
+ #tokenizer=tokenizer )
26
+
27
+ # Build a text generation pipeline
28
+ generator = pipeline(
29
+ task="text-generation",
30
  model=model,
31
+ tokenizer=tokenizer,
32
+ return_full_text=False
33
  )
34
 
35
  # Function to generate stories
36
+ def generate_story(prompt, max_tokens=300, temperature=0.8):
37
+ try:
38
+ chat_prompt = to_chat_prompt(prompt)
39
+ outputs = generator(
40
+ chat_prompt,
41
+ max_new_tokens=int(max_tokens),
42
+ temperature=float(temperature),
43
+ do_sample=True,
44
+ top_p=0.95,
45
+ top_k=50,
46
+ repetition_penalty=1.05,
47
+ pad_token_id=tokenizer.pad_token_id,
48
+ eos_token_id=tokenizer.eos_token_id
49
+ )
50
+ return outputs[0]["generated_text"]
51
+ except Exception as e:
52
+ return f"Error during generation: {type(e).__name__}: {e}"
53
 
54
  # Gradio UI
55
  with gr.Blocks() as demo:
 
61
  lines=3
62
  )
63
 
64
+ max_length = gr.Slider(50, 1000, value=300, step=50, label="Story Length in new tokens")
65
  temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Creativity")
66
 
67
  generate_btn = gr.Button("✨ Generate Story")