network01 commited on
Commit
e7b16a5
·
verified ·
1 Parent(s): 75b10dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -25
app.py CHANGED
@@ -1,44 +1,91 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- # 使用一个轻量文本生成模型(CPU 可跑)
5
- generator = pipeline(
6
- "text-generation",
7
- model="distilgpt2"
8
- )
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def rewrite(text, style):
11
  if style == "More formal":
12
- prompt = f"Rewrite this sentence in a more formal tone:\n{text}"
13
- elif style == "More friendly":
14
- prompt = f"Rewrite this sentence in a friendly tone:\n{text}"
15
- else:
16
- prompt = f"Rewrite this sentence to be shorter:\n{text}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- result = generator(prompt, max_length=100, num_return_sequences=1)
19
- return result[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  with gr.Blocks(title="AI Text Rewriter") as demo:
22
  gr.Markdown(
23
- "### ✨ AI Text Rewriter\n"
24
- "Paste a sentence and let AI rewrite it for you."
25
  )
26
 
 
 
 
 
 
 
 
27
  text_input = gr.Textbox(
28
  label="Your text",
29
- lines=3,
30
- placeholder="Type something here..."
31
  )
32
 
33
- style = gr.Radio(
34
- ["More formal", "More friendly", "Shorter"],
35
- label="Rewrite style",
36
- value="More friendly"
37
- )
38
 
39
- output = gr.Textbox(label="AI Result", lines=4)
40
 
41
- btn = gr.Button("Rewrite with AI")
42
  btn.click(fn=rewrite, inputs=[text_input, style], outputs=output)
 
43
 
44
  demo.launch()
 
1
  import gradio as gr
2
+ from functools import lru_cache
3
 
4
+ @lru_cache(maxsize=1)
5
+ def get_rewriter():
6
+ """
7
+ Use a lightweight instruction-following model that can run on CPU.
8
+ flan-t5-small is generally more reliable for "rewrite" than GPT2-style models.
9
+ """
10
+ from transformers import pipeline
11
+ return pipeline(
12
+ task="text2text-generation",
13
+ model="google/flan-t5-small",
14
+ device=-1, # CPU
15
+ )
16
+
17
+ def build_prompt(text: str, style: str) -> str:
18
+ text = (text or "").strip()
19
+ if not text:
20
+ return ""
21
 
 
22
  if style == "More formal":
23
+ return (
24
+ "Rewrite the text in a more formal tone. "
25
+ "Keep the original meaning. Output only the rewritten text.\n\n"
26
+ f"Text: {text}"
27
+ )
28
+ if style == "More friendly":
29
+ return (
30
+ "Rewrite the text in a friendly, warm tone. "
31
+ "Keep the original meaning. Output only the rewritten text.\n\n"
32
+ f"Text: {text}"
33
+ )
34
+ return (
35
+ "Rewrite the text to be shorter and clearer. "
36
+ "Keep the original meaning. Output only the rewritten text.\n\n"
37
+ f"Text: {text}"
38
+ )
39
+
40
+ def rewrite(text: str, style: str) -> str:
41
+ text = (text or "").strip()
42
+ if not text:
43
+ return "Please enter some text."
44
 
45
+ prompt = build_prompt(text, style)
46
+
47
+ try:
48
+ rewriter = get_rewriter()
49
+ out = rewriter(
50
+ prompt,
51
+ max_new_tokens=128,
52
+ do_sample=False,
53
+ )
54
+ result = (out[0].get("generated_text") or "").strip()
55
+ return result if result else "No output. Try a shorter input."
56
+ except Exception as e:
57
+ return (
58
+ "Error: failed to run the model.\n"
59
+ "If this is the first run, the Space may still be downloading the model.\n\n"
60
+ f"Details: {type(e).__name__}: {e}"
61
+ )
62
 
63
  with gr.Blocks(title="AI Text Rewriter") as demo:
64
  gr.Markdown(
65
+ "AI Text Rewriter\n"
66
+ "Paste a sentence or short paragraph, choose a style, then rewrite with AI."
67
  )
68
 
69
+ with gr.Row():
70
+ style = gr.Radio(
71
+ ["More formal", "More friendly", "Shorter"],
72
+ value="More friendly",
73
+ label="Rewrite style",
74
+ )
75
+
76
  text_input = gr.Textbox(
77
  label="Your text",
78
+ placeholder="Type or paste text here...",
79
+ lines=5,
80
  )
81
 
82
+ with gr.Row():
83
+ btn = gr.Button("Rewrite with AI")
84
+ clear = gr.Button("Clear")
 
 
85
 
86
+ output = gr.Textbox(label="Result", lines=6)
87
 
 
88
  btn.click(fn=rewrite, inputs=[text_input, style], outputs=output)
89
+ clear.click(fn=lambda: ("", ""), inputs=None, outputs=[text_input, output])
90
 
91
  demo.launch()