CHUNYU0505 commited on
Commit
c4310e4
·
verified ·
1 Parent(s): e1aabb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -45
app.py CHANGED
@@ -1,16 +1,16 @@
1
  # app.py
2
- import os, torch
3
  from langchain.docstore.document import Document
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from docx import Document as DocxDocument
8
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
  from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
- # 0. 載入向量資料庫
14
  # -------------------------------
15
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
16
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
@@ -25,48 +25,25 @@ else:
25
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
26
 
27
  # -------------------------------
28
- # 1. 模型設定(中文 GPT2 + fallback
29
  # -------------------------------
30
- PRIMARY_MODEL = "uer/gpt2-chinese-cluecorpusmedium"
31
- FALLBACK_MODEL = "uer/gpt2-chinese-cluecorpussmall"
32
 
33
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
34
  if HF_TOKEN:
35
  login(token=HF_TOKEN)
36
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
37
 
38
- def try_download_model(repo_id):
39
- local_dir = f"./models/{repo_id.split('/')[-1]}"
40
- if not os.path.exists(local_dir):
41
- print(f"⬇️ 嘗試下載模型 {repo_id} ...")
42
- try:
43
- snapshot_download(repo_id=repo_id, token=HF_TOKEN, local_dir=local_dir)
44
- except Exception as e:
45
- print(f"⚠️ 模型 {repo_id} 無法下載: {e}")
46
- return None
47
- return local_dir
48
-
49
- LOCAL_MODEL_DIR = try_download_model(PRIMARY_MODEL)
50
- if LOCAL_MODEL_DIR is None:
51
- print("⚠️ 切換到 fallback 模型:小型 GPT2-Chinese")
52
- LOCAL_MODEL_DIR = try_download_model(FALLBACK_MODEL)
53
- MODEL_NAME = FALLBACK_MODEL
54
- else:
55
- MODEL_NAME = PRIMARY_MODEL
56
-
57
- print(f"👉 最終使用模型:{MODEL_NAME}")
58
 
59
- # -------------------------------
60
- # 2. pipeline 載入
61
- # -------------------------------
62
  tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
63
- model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_DIR)
64
-
65
- if tokenizer.pad_token is None:
66
- tokenizer.pad_token = tokenizer.eos_token
67
 
68
  generator = pipeline(
69
- "text-generation",
70
  model=model,
71
  tokenizer=tokenizer,
72
  device=-1 # CPU
@@ -74,21 +51,18 @@ generator = pipeline(
74
 
75
  def call_local_inference(prompt, max_new_tokens=256):
76
  try:
77
- if "中文" not in prompt:
78
- prompt += "\n(請用中文回答,且只依據提供的內容生成,不可加入其他知識)"
79
  outputs = generator(
80
  prompt,
81
  max_new_tokens=max_new_tokens,
82
- do_sample=True,
83
- temperature=0.7,
84
- pad_token_id=tokenizer.pad_token_id
85
  )
86
  return outputs[0]["generated_text"]
87
  except Exception as e:
88
  return f"(生成失敗:{e})"
89
 
90
  # -------------------------------
91
- # 3. 僅基於 RAG 的文章生成
92
  # -------------------------------
93
  def generate_article_rag_only(query, segments=3):
94
  docx_file = "/tmp/generated_article.docx"
@@ -103,8 +77,8 @@ def generate_article_rag_only(query, segments=3):
103
  context_texts = [d.page_content for d in retrieved_docs]
104
  full_context = "\n".join(context_texts)
105
 
106
- # 切分成小片段,避免模型超載
107
- splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
108
  chunks = splitter.split_text(full_context)
109
 
110
  for i, chunk in enumerate(chunks[:segments]):
@@ -113,7 +87,7 @@ def generate_article_rag_only(query, segments=3):
113
  f"以下是唯一可用的參考內容:\n{chunk}\n\n"
114
  f"請基於這些內容,寫一段約150-200字的中文文章,"
115
  f"主題:{query}。\n"
116
- f"⚠️ 僅能使用參考內容,不可加入其他知識。"
117
  )
118
  paragraph = call_local_inference(prompt)
119
  all_text.append(paragraph)
@@ -129,11 +103,11 @@ def generate_article_rag_only(query, segments=3):
129
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{MODEL_NAME}", full_context, final_progress
130
 
131
  # -------------------------------
132
- # 4. Gradio 介面
133
  # -------------------------------
134
  with gr.Blocks() as demo:
135
  gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統")
136
- gr.Markdown("只基於 faiss_db 內容生成文章,不加入外部知識。")
137
 
138
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
139
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")
 
1
  # app.py
2
+ import os
3
  from langchain.docstore.document import Document
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from docx import Document as DocxDocument
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
9
  from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
+ # 0. 向量資料庫載入
14
  # -------------------------------
15
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
16
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
 
25
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
26
 
27
  # -------------------------------
28
+ # 1. 中文模型(Randeng-T5
29
  # -------------------------------
30
+ MODEL_NAME = "IDEA-CCNL/Randeng-T5-784M-Summary-Chinese"
 
31
 
32
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
33
  if HF_TOKEN:
34
  login(token=HF_TOKEN)
35
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
36
 
37
+ LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.split('/')[-1]}"
38
+ if not os.path.exists(LOCAL_MODEL_DIR):
39
+ print(f"⬇️ 嘗試下載模型 {MODEL_NAME} ...")
40
+ snapshot_download(repo_id=MODEL_NAME, token=HF_TOKEN, local_dir=LOCAL_MODEL_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
 
42
  tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
43
+ model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
 
 
 
44
 
45
  generator = pipeline(
46
+ "text2text-generation",
47
  model=model,
48
  tokenizer=tokenizer,
49
  device=-1 # CPU
 
51
 
52
  def call_local_inference(prompt, max_new_tokens=256):
53
  try:
 
 
54
  outputs = generator(
55
  prompt,
56
  max_new_tokens=max_new_tokens,
57
+ do_sample=False, # 用摘要模型 → 不建議隨機取樣
58
+ temperature=0.7
 
59
  )
60
  return outputs[0]["generated_text"]
61
  except Exception as e:
62
  return f"(生成失敗:{e})"
63
 
64
  # -------------------------------
65
+ # 2. 基於 RAG 的文章生成
66
  # -------------------------------
67
  def generate_article_rag_only(query, segments=3):
68
  docx_file = "/tmp/generated_article.docx"
 
77
  context_texts = [d.page_content for d in retrieved_docs]
78
  full_context = "\n".join(context_texts)
79
 
80
+ # 切分 context,避免太長
81
+ splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
82
  chunks = splitter.split_text(full_context)
83
 
84
  for i, chunk in enumerate(chunks[:segments]):
 
87
  f"以下是唯一可用的參考內容:\n{chunk}\n\n"
88
  f"請基於這些內容,寫一段約150-200字的中文文章,"
89
  f"主題:{query}。\n"
90
+ f"⚠️ 僅能使用參考內容,不可加入外部知識。"
91
  )
92
  paragraph = call_local_inference(prompt)
93
  all_text.append(paragraph)
 
103
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{MODEL_NAME}", full_context, final_progress
104
 
105
  # -------------------------------
106
+ # 3. Gradio 介面
107
  # -------------------------------
108
  with gr.Blocks() as demo:
109
  gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統")
110
+ gr.Markdown("只基於 faiss_db 內容生成中文文章。")
111
 
112
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
113
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")