Spaces:
Paused
Paused
Nattapong Tapachoom
Refactor app.py to improve model loading and PDF processing; update dataset generation logic and enhance UI components
084df26
| import os | |
| import io | |
| import re | |
| import json | |
| from datetime import datetime | |
| from typing import List, Dict, Any, Tuple | |
| import gradio as gr | |
| from pypdf import PdfReader | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| # โหลดโมเดลเริ่มต้น (default) | |
| DEFAULT_MODEL = "HuggingFaceH4/zephyr-7b-beta" | |
| # สร้าง pipeline global | |
| gen_pipe = None | |
| tokenizer = None | |
| current_model_id = None | |
| def load_model(model_id: str, hf_token: str = None): | |
| global gen_pipe, tokenizer, current_model_id | |
| if current_model_id == model_id and gen_pipe is not None: | |
| return gen_pipe | |
| print(f"🔄 Loading model: {model_id}") | |
| tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token) | |
| model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token, device_map="auto") | |
| gen_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto") | |
| current_model_id = model_id | |
| return gen_pipe | |
| def ensure_output_dir() -> str: | |
| outdir = os.path.join(os.getcwd(), "outputs") | |
| os.makedirs(outdir, exist_ok=True) | |
| return outdir | |
| def read_pdfs(files: List[gr.File]) -> Tuple[str, List[Dict[str, Any]]]: | |
| docs = [] | |
| combined_text_parts: List[str] = [] | |
| for f in files: | |
| path = f.name if hasattr(f, "name") else f | |
| reader = PdfReader(path) | |
| pages_text = [] | |
| for i, page in enumerate(reader.pages): | |
| text = page.extract_text() or "" | |
| text = re.sub(r"\s+", " ", text).strip() | |
| if text: | |
| pages_text.append({"page": i + 1, "text": text}) | |
| combined_text_parts.append(text) | |
| docs.append({"file": os.path.basename(path), "pages": pages_text}) | |
| combined_text = "\n\n".join(combined_text_parts) | |
| return combined_text, docs | |
| def chunk_text(text: str, chunk_size: int = 1500, overlap: int = 200, max_chunks: int = 5) -> List[str]: | |
| text = text.strip() | |
| if not text: | |
| return [] | |
| chunks: List[str] = [] | |
| start = 0 | |
| n = len(text) | |
| while start < n and len(chunks) < max_chunks: | |
| end = min(start + chunk_size, n) | |
| chunk = text[start:end] | |
| chunks.append(chunk) | |
| if end >= n: | |
| break | |
| start = max(end - overlap, 0) | |
| return chunks | |
| # เทมเพลต prompt พื้นฐาน | |
| DEFAULT_QA_PROMPT = ( | |
| "คุณเป็นผู้ช่วยสร้างชุดข้อมูล อ่านเนื้อหานี้แล้วสร้างคำถาม-คำตอบ " | |
| "จำนวน {min_pairs} ถึง {max_pairs} คู่ " | |
| "ส่งคืน JSON array ที่มี objects รูปแบบ {{\"question\": str, \"answer\": str}} เท่านั้น\n\n" | |
| "เนื้อหา:\n{content}\n" | |
| ) | |
| def generate_dataset(files: List[gr.File], | |
| task: str, | |
| preset_model: str, | |
| custom_model_id: str, | |
| hf_token: str, | |
| chunk_size: int, | |
| overlap: int, | |
| max_chunks: int, | |
| max_new_tokens: int, | |
| temperature: float, | |
| min_pairs: int, | |
| max_pairs: int): | |
| if not files: | |
| return "❌ กรุณาอัปโหลดไฟล์ PDF", None, None | |
| # โหลดโมเดล | |
| model_id = (custom_model_id or "").strip() or preset_model or DEFAULT_MODEL | |
| pipe = load_model(model_id, hf_token or None) | |
| # อ่าน PDF และตัดเป็น chunk | |
| full_text, _ = read_pdfs(files) | |
| chunks = chunk_text(full_text, chunk_size, overlap, max_chunks) | |
| if not chunks: | |
| return "❌ ไม่สามารถดึงข้อความจาก PDF", None, None | |
| results = [] | |
| for ch in chunks: | |
| prompt = DEFAULT_QA_PROMPT.format( | |
| min_pairs=min_pairs, | |
| max_pairs=max_pairs, | |
| content=ch | |
| ) | |
| output = pipe(prompt, | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| do_sample=temperature > 0.0)[0]["generated_text"] | |
| # พยายาม extract JSON | |
| start, end = output.find("["), output.rfind("]") | |
| if start != -1 and end != -1: | |
| try: | |
| data = json.loads(output[start:end + 1]) | |
| if isinstance(data, list): | |
| results.extend(data) | |
| except Exception: | |
| pass | |
| if not results: | |
| return "❌ ไม่สามารถสร้างข้อมูล JSON ได้", None, None | |
| # Save output | |
| outdir = ensure_output_dir() | |
| ts = datetime.utcnow().strftime("%Y%m%d_%H%M%S") | |
| json_path = os.path.join(outdir, f"dataset_{task}_{ts}.json") | |
| jsonl_path = os.path.join(outdir, f"dataset_{task}_{ts}.jsonl") | |
| with io.open(json_path, "w", encoding="utf-8") as f: | |
| json.dump(results, f, ensure_ascii=False, indent=2) | |
| with io.open(jsonl_path, "w", encoding="utf-8") as f: | |
| for item in results: | |
| f.write(json.dumps(item, ensure_ascii=False) + "\n") | |
| return f"✅ สร้างข้อมูลสำเร็จ {len(results)} รายการ", json_path, jsonl_path | |
| # ---------------- Gradio UI ---------------- | |
| PRESET_MODELS = [ | |
| DEFAULT_MODEL, | |
| "mistralai/Mistral-7B-Instruct-v0.2", | |
| "meta-llama/Llama-2-7b-chat-hf", | |
| "google/flan-t5-large" | |
| ] | |
| with gr.Blocks(title="Thai PDF → Dataset Generator") as demo: | |
| gr.Markdown("# 📚 Thai Auto Dataset Generator") | |
| with gr.Row(): | |
| pdf_files = gr.File(label="อัปโหลด PDF", file_count="multiple", file_types=[".pdf"]) | |
| with gr.Row(): | |
| task = gr.Textbox(label="Task", value="QA") | |
| preset_model = gr.Dropdown(label="Preset Model", choices=PRESET_MODELS, value=DEFAULT_MODEL) | |
| custom_model_id = gr.Textbox(label="Custom Model ID", placeholder="org/model-name") | |
| hf_token = gr.Textbox(label="HF Token", type="password") | |
| with gr.Row(): | |
| max_new_tokens = gr.Slider(64, 1024, value=512, step=16, label="Max New Tokens") | |
| temperature = gr.Slider(0.0, 1.5, value=0.3, step=0.05, label="Temperature") | |
| with gr.Row(): | |
| chunk_size = gr.Slider(500, 4000, value=1500, step=50, label="Chunk Size") | |
| overlap = gr.Slider(0, 1000, value=200, step=50, label="Overlap") | |
| max_chunks = gr.Slider(1, 20, value=5, step=1, label="Max Chunks") | |
| with gr.Row(): | |
| min_pairs = gr.Slider(1, 10, value=3, step=1, label="Min Pairs") | |
| max_pairs = gr.Slider(1, 12, value=6, step=1, label="Max Pairs") | |
| generate_btn = gr.Button("🚀 Generate Dataset") | |
| status = gr.Markdown() | |
| out_json = gr.File(label="JSON") | |
| out_jsonl = gr.File(label="JSONL") | |
| generate_btn.click( | |
| fn=generate_dataset, | |
| inputs=[pdf_files, task, preset_model, custom_model_id, hf_token, | |
| chunk_size, overlap, max_chunks, max_new_tokens, temperature, | |
| min_pairs, max_pairs], | |
| outputs=[status, out_json, out_jsonl] | |
| ) | |
| if __name__ == "__main__": | |
| demo.queue().launch() |