CHUNYU0505 commited on
Commit
052f25a
·
verified ·
1 Parent(s): 8d688b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -59
app.py CHANGED
@@ -5,92 +5,74 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from docx import Document as DocxDocument
8
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
9
  from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
- # 1. 模型設定(中文 T5 + fallback
14
  # -------------------------------
15
- PRIMARY_MODEL = "Langboat/mengzi-t5-base" # ✅ spiece.model
16
- FALLBACK_MODEL = "uer/t5-small-chinese-cluecorpussmall"
17
 
18
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
19
  if HF_TOKEN:
20
  login(token=HF_TOKEN)
21
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
22
 
23
- def try_download_model(repo_id):
24
- local_dir = f"./models/{repo_id.split('/')[-1]}"
25
- if not os.path.exists(local_dir):
26
- print(f"⬇️ 嘗試下載模型 {repo_id} ...")
27
- try:
28
- snapshot_download(repo_id=repo_id, token=HF_TOKEN, local_dir=local_dir)
29
- except Exception as e:
30
- print(f"⚠️ 模型 {repo_id} 無法下載: {e}")
31
- return None
32
- return local_dir
33
-
34
- # 嘗試下載 Primary,失敗就換 Small
35
- LOCAL_MODEL_DIR = try_download_model(PRIMARY_MODEL)
36
- if LOCAL_MODEL_DIR is None:
37
- print("⚠️ 切換到 fallback 模型:小型 T5-Chinese")
38
- LOCAL_MODEL_DIR = try_download_model(FALLBACK_MODEL)
39
- MODEL_NAME = FALLBACK_MODEL
40
- else:
41
- MODEL_NAME = PRIMARY_MODEL
42
 
43
  print(f"👉 最終使用模型:{MODEL_NAME}")
44
 
45
  # -------------------------------
46
- # 2. pipeline 載入 (Seq2SeqLM for T5)
47
  # -------------------------------
48
  tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
49
- model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
50
-
51
- generator = pipeline(
52
- "text2text-generation",
53
- model=model,
54
- tokenizer=tokenizer,
55
- device=-1 # CPU
56
- )
57
-
58
- def call_local_inference(prompt, max_new_tokens=256):
59
- try:
60
- outputs = generator(
61
- prompt,
62
- max_new_tokens=max_new_tokens,
63
- do_sample=True,
64
- temperature=0.7
65
- )
66
- return outputs[0]["generated_text"]
67
- except Exception as e:
68
- return f"(生成失敗:{e})"
69
 
70
  # -------------------------------
71
- # 3. 建立/載入向量資料庫
72
  # -------------------------------
73
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
74
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
75
 
76
- DB_PATH = "./faiss_db"
77
- if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
78
  print("✅ 載入現有向量資料庫...")
79
- db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
80
  else:
81
- print("⚠️ 沒有找到資料庫,請先建立 faiss_db")
82
  db = None
83
 
84
- retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3}) if db else None
85
 
86
  # -------------------------------
87
- # 4. 文章生成(加入 RAG)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  # -------------------------------
89
  def generate_article_progress(query, segments=5):
90
  docx_file = "/tmp/generated_article.docx"
91
  doc = DocxDocument()
92
  doc.add_heading(query, level=1)
93
-
94
  all_text = []
95
 
96
  context = ""
@@ -101,30 +83,34 @@ def generate_article_progress(query, segments=5):
101
 
102
  for i in range(segments):
103
  prompt = (
104
- f"以下是佛教經論的相關段落:\n{context}\n\n"
105
- f"請依據上面內容,寫一段約150-200字的中文文章,"
106
- f"主題:{query}。\n第{i+1}段:"
 
 
107
  )
 
108
  paragraph = call_local_inference(prompt)
109
  all_text.append(paragraph)
110
  doc.add_paragraph(paragraph)
 
111
  yield "\n\n".join(all_text), None, f"本次使用模型:{MODEL_NAME}"
112
 
113
  doc.save(docx_file)
114
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{MODEL_NAME}"
115
 
116
  # -------------------------------
117
- # 5. Gradio 介面
118
  # -------------------------------
119
  with gr.Blocks() as demo:
120
  gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
121
- gr.Markdown("使用 FAISS + 中文 T5 模型,基於資料庫內容生成文章。")
122
 
123
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
124
- segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
125
  output_text = gr.Textbox(label="生成文章")
126
  output_file = gr.File(label="下載 DOCX")
127
- model_info = gr.Label(label="模型資訊")
128
 
129
  btn = gr.Button("生成文章")
130
  btn.click(
 
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from docx import Document as DocxDocument
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
  from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
+ # 1. 模型設定(專門中文,T5)
14
  # -------------------------------
15
+ MODEL_NAME = "Langboat/mengzi-t5-base" # ✅ CPU 也能跑的中文 T5
16
+ LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.split('/')[-1]}"
17
 
18
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
19
  if HF_TOKEN:
20
  login(token=HF_TOKEN)
21
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
22
 
23
+ if not os.path.exists(LOCAL_MODEL_DIR):
24
+ print(f"⬇️ 嘗試下載模型 {MODEL_NAME} ...")
25
+ snapshot_download(repo_id=MODEL_NAME, token=HF_TOKEN, local_dir=LOCAL_MODEL_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  print(f"👉 最終使用模型:{MODEL_NAME}")
28
 
29
  # -------------------------------
30
+ # 2. 載入 tokenizer + model
31
  # -------------------------------
32
  tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
33
+ model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR, device_map="cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # -------------------------------
36
+ # 3. 向量資料庫載入
37
  # -------------------------------
38
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
39
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
40
 
41
+ if os.path.exists("./faiss_db/index.faiss"):
 
42
  print("✅ 載入現有向量資料庫...")
43
+ db = FAISS.load_local("./faiss_db", embeddings_model, allow_dangerous_deserialization=True)
44
  else:
45
+ print("⚠️ 找不到向量資料庫,請先建立")
46
  db = None
47
 
48
+ retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5}) if db else None
49
 
50
  # -------------------------------
51
+ # 4. 改良推理函數(避免重複亂碼)
52
+ # -------------------------------
53
+ def call_local_inference(prompt, max_new_tokens=256):
54
+ try:
55
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
56
+ outputs = model.generate(
57
+ **inputs,
58
+ max_new_tokens=max_new_tokens,
59
+ do_sample=False, # ❌ 關掉隨機
60
+ num_beams=4, # ✅ 用 beam search
61
+ repetition_penalty=1.2, # ✅ 懲罰重複
62
+ length_penalty=1.0,
63
+ early_stopping=True
64
+ )
65
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
66
+ except Exception as e:
67
+ return f"(生成失敗:{e})"
68
+
69
+ # -------------------------------
70
+ # 5. 文章生成(加入 RAG)
71
  # -------------------------------
72
  def generate_article_progress(query, segments=5):
73
  docx_file = "/tmp/generated_article.docx"
74
  doc = DocxDocument()
75
  doc.add_heading(query, level=1)
 
76
  all_text = []
77
 
78
  context = ""
 
83
 
84
  for i in range(segments):
85
  prompt = (
86
+ f"請基於以下資料,撰寫一段中文文章:\n"
87
+ f"主題:{query}\n"
88
+ f"要求:字數約150~200字,內容要有完整句子,不要重複詞語。\n\n"
89
+ f"參考資料:\n{context}\n\n"
90
+ f"第{i+1}段:"
91
  )
92
+
93
  paragraph = call_local_inference(prompt)
94
  all_text.append(paragraph)
95
  doc.add_paragraph(paragraph)
96
+
97
  yield "\n\n".join(all_text), None, f"本次使用模型:{MODEL_NAME}"
98
 
99
  doc.save(docx_file)
100
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{MODEL_NAME}"
101
 
102
  # -------------------------------
103
+ # 6. Gradio 介面
104
  # -------------------------------
105
  with gr.Blocks() as demo:
106
  gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
107
+ gr.Markdown("基於向量資料庫 + 中文 T5 模型,自動生成主題文章")
108
 
109
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
110
+ segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")
111
  output_text = gr.Textbox(label="生成文章")
112
  output_file = gr.File(label="下載 DOCX")
113
+ model_info = gr.Textbox(label="模型資訊")
114
 
115
  btn = gr.Button("生成文章")
116
  btn.click(