| |
| |
| |
| |
|
|
| import os, json, gc, tempfile, uuid, textwrap, traceback, sys |
| from pathlib import Path |
|
|
| import torch, gradio as gr, markdown |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
| |
# Optional dependency: WeasyPrint needs native Cairo/Pango libraries, so the
# import can fail even when the wheel is installed. Record availability in a
# flag and let generate_book() report the problem at call time.
WEASYPRINT_OK = False
try:
    from weasyprint import HTML
    WEASYPRINT_OK = True
except ImportError:
    pass
|
|
| |
# Hugging Face access token, required for gated/remote model downloads.
# Fail fast at startup with an actionable message rather than deep inside
# the model-loading call.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError(
        "HF_TOKEN not found—add it to your Space secrets or export it locally."
    )
|
|
| |
| torch.set_default_dtype(torch.float16) |
|
|
| |
# Model selection (overridable via the MODEL_ID env var) and the local
# directory used both as HF download cache and as app scratch space.
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen1.5-14B-Chat")
CACHE_DIR = Path("./qwen-cache")
CACHE_DIR.mkdir(exist_ok=True)
|
|
# Download (first run only) and load the tokenizer and model. Hub-related
# keyword arguments are shared between the two calls.
print(f"\n→ Loading {MODEL_ID}… this may take a minute on first run.\n")

_hub_kwargs = {
    "cache_dir": str(CACHE_DIR),
    "token": HF_TOKEN,
    "trust_remote_code": True,
}
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **_hub_kwargs)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    **_hub_kwargs,
).eval()
|
|
| HISTORY_FILE = CACHE_DIR / "chat_history.json" |
|
|
def chat_fn(message, history):
    """Gradio ChatInterface callback with short persistent history.

    Parameters
    ----------
    message : str
        The new user turn.
    history : list[tuple[str, str]]
        Prior (user, assistant) pairs as supplied by gr.ChatInterface.

    Returns
    -------
    str
        The assistant's reply for this turn.
    """
    # Rebuild the whole transcript in the chat-template message format.
    msgs = [{"role": "system", "content": "You are Obsidian, a helpful AI assistant."}]
    for u, a in history:
        msgs.append({"role": "user", "content": u})
        msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        out = model.generate(
            **inputs,  # fix: forward attention_mask along with input_ids
            max_new_tokens=2048,
            do_sample=True,  # fix: temperature/top_p are ignored under greedy decoding
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens generated past the prompt.
    reply = tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True).strip()
    history.append((message, reply))

    # Best-effort persistence of the last 20 turns; a disk error must not
    # kill the chat, but other exception types should surface.
    try:
        HISTORY_FILE.write_text(json.dumps(history[-20:], ensure_ascii=False))
    except OSError:
        pass

    # Free GPU memory promptly between turns.
    del inputs, out
    torch.cuda.empty_cache()
    gc.collect()
    return reply
|
|
| |
# Fixed outline for the generated book: front matter, 25 numbered chapters,
# back matter. Each entry pairs a section title with its target word count.
_FRONT_MATTER = [
    ("Cover", 100),
    ("Dedication", 200),
    ("Preface", 2000),
    ("Introduction", 1500),
]
_BACK_MATTER = [
    ("Conclusion", 1500),
    ("Afterword", 1000),
    ("Appendix", 1200),
    ("Index", 800),
    ("Back Cover", 100),
]
BOOK_SECTION_BLUEPRINT = [
    {"title": title, "words": words}
    for title, words in (
        _FRONT_MATTER
        + [(f"Chapter {n}", 2000) for n in range(1, 26)]
        + _BACK_MATTER
    )
]
|
|
def _one_pass(msgs, max_tokens):
    """Run one generation call and return the newly generated text.

    Parameters
    ----------
    msgs : list[dict]
        Chat messages in {"role": ..., "content": ...} format.
    max_tokens : int
        Upper bound on new tokens to generate.

    Returns
    -------
    str
        The decoded completion (prompt stripped), whitespace-trimmed.
    """
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,  # fix: forward attention_mask along with input_ids
            max_new_tokens=max_tokens,
            do_sample=True,  # fix: temperature/top_p only apply when sampling is enabled
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the tokens generated past the prompt.
    text = tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
    # Release GPU memory between passes.
    del inputs, out
    torch.cuda.empty_cache()
    gc.collect()
    return text.strip()
|
|
def _generate_section(title: str, theme: str, target_words: int, max_passes: int = 8) -> str:
    """Generate one book section, looping until the word budget is met.

    Parameters
    ----------
    title : str
        Section title from the blueprint (e.g. "Chapter 3").
    theme : str
        User-supplied book theme.
    target_words : int
        Approximate word count to reach.
    max_passes : int
        Safety cap on generation passes so a stubborn model cannot loop forever.

    Returns
    -------
    str
        The accumulated section text, wrapped at 90 columns.
    """
    seed = (
        f"Write the '{title}' section for a book about “{theme}”. "
        f"Aim for about {target_words} words, engaging style, no markdown headings."
    )
    msgs = [{"role": "system", "content": "You are Obsidian, a professional book author."},
            {"role": "user", "content": seed}]
    text, passes = "", 0
    while len(text.split()) < target_words and passes < max_passes:
        rem = target_words - len(text.split())
        if passes > 0:
            msgs.append({
                "role": "user",
                "content": f"Continue the text; roughly {rem} more words needed. Avoid repetition."
            })
        chunk = _one_pass(msgs, max_tokens=max(64, int(rem * 1.6)))
        # Bug fix: record the model's own output in the conversation, so that
        # continuation passes can actually see what has been written so far
        # (previously "continue" was asked without any prior assistant turn).
        msgs.append({"role": "assistant", "content": chunk})
        text += ("" if not text else " ") + chunk
        passes += 1
    return textwrap.fill(text.strip(), 90)
|
|
# Inline stylesheet embedded into the generated book HTML. WeasyPrint applies
# these print-oriented rules (cm margins, page-break hints) when rendering the
# PDF.
CSS_STYLE = """
<style>
body { font-family: 'Georgia', serif; font-size: 12pt; line-height: 1.6; margin: 2.5cm; }
h1 { font-size: 24pt; margin: 0 0 0.5em; page-break-after: avoid; }
p { margin: 0 0 0.8em; text-align: justify; }
</style>
"""
|
|
def generate_book(theme: str) -> str:
    """Assemble HTML for every blueprint section and render it to a PDF.

    Parameters
    ----------
    theme : str
        User-supplied book theme, forwarded to every section prompt.

    Returns
    -------
    str
        Path to the rendered PDF inside a fresh temporary directory.

    Raises
    ------
    RuntimeError
        If WeasyPrint (or its native Cairo/Pango dependencies) is unavailable.
    """
    if not WEASYPRINT_OK:
        raise RuntimeError("WeasyPrint import failed. Ensure Cairo & Pango are installed.")
    # Local import: the local variable `html` below would shadow the module name.
    from html import escape

    tmp = Path(tempfile.mkdtemp(prefix="book_"))
    parts = []
    for sec in BOOK_SECTION_BLUEPRINT:
        print(f"Generating {sec['title']} ({sec['words']} words)…")
        content = _generate_section(sec['title'], theme, sec['words'])
        # Bug fix: escape model output (and titles) so stray '<', '>' or '&'
        # cannot break or inject markup into the generated HTML document.
        parts.append(f"<h1>{escape(sec['title'])}</h1><p>{escape(content)}</p>")
    html = (
        "<html><head><meta charset='UTF-8'>" + CSS_STYLE + "</head><body>"
        + "<div style='page-break-after:always'></div>".join(parts)
        + "</body></html>"
    )
    html_path = tmp / "book.html"
    pdf_path = tmp / f"book_{uuid.uuid4()}.pdf"
    html_path.write_text(html, encoding="utf-8")
    HTML(str(html_path)).write_pdf(str(pdf_path))
    return str(pdf_path)
|
|
# ---- Gradio UI: one tab for interactive chat, one for whole-book generation. ----
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🪄 Obsidian Chat & 300-Page Book Maker")
    with gr.Tab("Chat"):
        gr.ChatInterface(fn=chat_fn, title=None)
    with gr.Tab("Book"):
        topic = gr.Textbox(label="Book theme", placeholder="e.g. A steampunk adventure")
        run_btn = gr.Button("Generate 300+ page book")
        status = gr.Textbox(label="Progress", interactive=False)
        pdf_out = gr.File(label="Download your PDF")
        def run(theme):
            # Generator callback: the first yield updates the status box
            # immediately; the second delivers the finished PDF or the error.
            yield gr.update(value="⏳ Generating – please wait…"), None
            try:
                pdf = generate_book(theme)
                yield gr.update(value="✅ Done! Click to download."), pdf
            except Exception as e:
                # Log the full traceback server-side; show only the short
                # message to the user.
                traceback.print_exc(file=sys.stderr)
                yield gr.update(value=f"⚠️ Failed: {e}"), None
        run_btn.click(run, inputs=topic, outputs=[status, pdf_out])
|
|
# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()