CHUNYU0505 commited on
Commit
3bcec19
·
verified ·
1 Parent(s): 9044d56

多模型下拉選單

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -52,17 +52,16 @@ retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
52
  # -------------------------------
53
  # 4. 定義 REST API 呼叫函數
54
  # -------------------------------
55
- INFERENCE_MODEL = "google/flan-t5-xl"
56
- API_URL = f"https://api-inference.huggingface.co/models/{INFERENCE_MODEL}"
57
  HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
58
 
59
- def call_hf_inference(prompt, max_new_tokens=512):
 
60
  payload = {
61
  "inputs": prompt,
62
  "parameters": {"max_new_tokens": max_new_tokens}
63
  }
64
  try:
65
- response = requests.post(API_URL, headers=HEADERS, json=payload, timeout=60)
66
  response.raise_for_status()
67
  data = response.json()
68
  if isinstance(data, list) and "generated_text" in data[0]:
@@ -90,7 +89,7 @@ def get_hf_rate_limit():
90
  # -------------------------------
91
  # 6. 生成文章(即時進度)
92
  # -------------------------------
93
- def generate_article_progress(query, segments=5):
94
  docx_file = "/tmp/generated_article.docx"
95
  doc = DocxDocument()
96
  doc.add_heading(query, level=1)
@@ -99,12 +98,11 @@ def generate_article_progress(query, segments=5):
99
  prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"
100
 
101
  for i in range(int(segments)):
102
- paragraph = call_hf_inference(prompt)
103
  all_text.append(paragraph)
104
  doc.add_paragraph(paragraph)
105
  prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"
106
 
107
- # yield 即時更新 Textbox
108
  yield "\n\n".join(all_text), None
109
 
110
  doc.save(docx_file)
@@ -119,6 +117,16 @@ with gr.Blocks() as demo:
119
  gr.Markdown("使用 Hugging Face REST API + FAISS RAG,生成文章並提示 API 剩餘額度。")
120
 
121
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
 
 
 
 
 
 
 
 
 
 
122
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
123
  output_text = gr.Textbox(label="生成文章 + API 剩餘次數")
124
  output_file = gr.File(label="下載 DOCX")
@@ -126,7 +134,7 @@ with gr.Blocks() as demo:
126
  btn = gr.Button("生成文章")
127
  btn.click(
128
  generate_article_progress,
129
- inputs=[query_input, segments_input],
130
  outputs=[output_text, output_file]
131
  )
132
 
 
52
  # -------------------------------
53
  # 4. 定義 REST API 呼叫函數
54
  # -------------------------------
 
 
55
  HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
56
 
57
+ def call_hf_inference(model_name, prompt, max_new_tokens=512):
58
+ api_url = f"https://api-inference.huggingface.co/models/{model_name}"
59
  payload = {
60
  "inputs": prompt,
61
  "parameters": {"max_new_tokens": max_new_tokens}
62
  }
63
  try:
64
+ response = requests.post(api_url, headers=HEADERS, json=payload, timeout=60)
65
  response.raise_for_status()
66
  data = response.json()
67
  if isinstance(data, list) and "generated_text" in data[0]:
 
89
  # -------------------------------
90
  # 6. 生成文章(即時進度)
91
  # -------------------------------
92
+ def generate_article_progress(query, model_name, segments=5):
93
  docx_file = "/tmp/generated_article.docx"
94
  doc = DocxDocument()
95
  doc.add_heading(query, level=1)
 
98
  prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"
99
 
100
  for i in range(int(segments)):
101
+ paragraph = call_hf_inference(model_name, prompt)
102
  all_text.append(paragraph)
103
  doc.add_paragraph(paragraph)
104
  prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"
105
 
 
106
  yield "\n\n".join(all_text), None
107
 
108
  doc.save(docx_file)
 
117
  gr.Markdown("使用 Hugging Face REST API + FAISS RAG,生成文章並提示 API 剩餘額度。")
118
 
119
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
120
+ model_dropdown = gr.Dropdown(
121
+ choices=[
122
+ "gpt2",
123
+ "EleutherAI/gpt-neo-2.7B",
124
+ "EleutherAI/gpt-j-6B",
125
+ "facebook/bart-large-cnn"
126
+ ],
127
+ value="EleutherAI/gpt-neo-2.7B",
128
+ label="選擇生成模型"
129
+ )
130
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
131
  output_text = gr.Textbox(label="生成文章 + API 剩餘次數")
132
  output_file = gr.File(label="下載 DOCX")
 
134
  btn = gr.Button("生成文章")
135
  btn.click(
136
  generate_article_progress,
137
+ inputs=[query_input, model_dropdown, segments_input],
138
  outputs=[output_text, output_file]
139
  )
140