# Source: Hugging Face Space upload by sseymens, commit c289d30 (verified).
# app.py
# """Obsidian Chat + 300-Page-Book Generator
# Deployable as a Gradio app on Hugging Face Spaces.
# Requires a HF read token in the secret HF_TOKEN and ≥24 GB VRAM."""
import os, json, gc, tempfile, uuid, textwrap, traceback, sys
from pathlib import Path
import torch, gradio as gr, markdown
from transformers import AutoTokenizer, AutoModelForCausalLM
# PDF backend: WeasyPrint is optional at import time. Availability is recorded
# in WEASYPRINT_OK and re-checked before any PDF rendering is attempted.
try:
    from weasyprint import HTML
except ImportError:
    WEASYPRINT_OK = False
else:
    WEASYPRINT_OK = True
# Hugging Face read token — fail fast at startup if it is missing, since both
# the tokenizer and the model download below require it.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN not found—add it to your Space secrets or export it locally."
    )
# Keep every newly created tensor in FP16 so activations match the FP16
# weights loaded below (avoids BF16/FP32 dtype mismatches).
torch.set_default_dtype(torch.float16)

# Model selection (overridable via the MODEL_ID env var) and on-disk cache.
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen1.5-14B-Chat")
CACHE_DIR = Path("./qwen-cache")
CACHE_DIR.mkdir(exist_ok=True)
print(f"\n→ Loading {MODEL_ID}… this may take a minute on first run.\n")

# Keyword arguments shared by the tokenizer and model loaders.
_HUB_KWARGS = dict(cache_dir=str(CACHE_DIR), token=HF_TOKEN, trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **_HUB_KWARGS)

# device_map="auto" shards the model across available GPUs; eval() disables
# dropout and other training-only behavior.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    **_HUB_KWARGS,
).eval()

# Persisted chat transcript (last 20 exchanges) lives next to the model cache.
HISTORY_FILE = CACHE_DIR / "chat_history.json"
def chat_fn(message, history):
    """Gradio ChatInterface callback with short persistent history.

    Parameters
    ----------
    message : str
        The user's latest message.
    history : list[tuple[str, str]]
        Prior (user, assistant) exchange pairs supplied by Gradio.

    Returns
    -------
    str
        The assistant's reply text.
    """
    msgs = [{"role": "system", "content": "You are Obsidian, a helpful AI assistant."}]
    for u, a in history:
        msgs.append({"role": "user", "content": u})
        msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,  # fix: pass attention_mask along with input_ids
            max_new_tokens=2048,
            do_sample=True,  # fix: temperature/top_p only apply when sampling
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (slice past the prompt length).
    reply = tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True).strip()
    history.append((message, reply))
    # Persist the last 20 exchanges, best-effort: a failed write must never
    # break the chat, but we only swallow I/O errors, not arbitrary bugs.
    try:
        HISTORY_FILE.write_text(json.dumps(history[-20:], ensure_ascii=False))
    except OSError:
        pass
    # Free GPU memory held by the prompt/output tensors.
    del inputs, out
    torch.cuda.empty_cache()
    gc.collect()
    return reply
# Book section blueprint: front matter, 25 chapters, back matter.
# Word budgets total ~58k words, roughly a 300-page book.
_FRONT_MATTER = [
    ("Cover", 100),
    ("Dedication", 200),
    ("Preface", 2000),
    ("Introduction", 1500),
]
_BACK_MATTER = [
    ("Conclusion", 1500),
    ("Afterword", 1000),
    ("Appendix", 1200),
    ("Index", 800),
    ("Back Cover", 100),
]
_CHAPTERS = [(f"Chapter {n}", 2000) for n in range(1, 26)]

BOOK_SECTION_BLUEPRINT = [
    {"title": title, "words": words}
    for title, words in _FRONT_MATTER + _CHAPTERS + _BACK_MATTER
]
def _one_pass(msgs, max_tokens):
    """Run one chat-templated generation and return only the new text.

    Parameters
    ----------
    msgs : list[dict]
        Chat messages in ``{"role": ..., "content": ...}`` form.
    max_tokens : int
        Upper bound on newly generated tokens.

    Returns
    -------
    str
        Decoded completion with the prompt stripped and whitespace trimmed.
    """
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,  # fix: pass attention_mask along with input_ids
            max_new_tokens=max_tokens,
            do_sample=True,  # fix: temperature/top_p only apply when sampling
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the tokens generated past the prompt.
    text = tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
    # Release prompt/output tensors before the next (potentially large) pass.
    del inputs, out
    torch.cuda.empty_cache()
    gc.collect()
    return text.strip()
def _generate_section(title: str, theme: str, target_words: int, max_passes: int = 8) -> str:
    """Generate one book section, looping until the word budget is met.

    Parameters
    ----------
    title : str
        Section title (e.g. "Chapter 3").
    theme : str
        The book's overall theme, interpolated into the prompt.
    target_words : int
        Approximate word budget for the section.
    max_passes : int
        Hard cap on generation calls, so a terse model cannot loop forever.

    Returns
    -------
    str
        The section text, wrapped to 90 columns.
    """
    seed = (
        f"Write the '{title}' section for a book about “{theme}”. "
        f"Aim for about {target_words} words, engaging style, no markdown headings."
    )
    msgs = [{"role": "system", "content": "You are Obsidian, a professional book author."},
            {"role": "user", "content": seed}]
    text, passes = "", 0
    while len(text.split()) < target_words and passes < max_passes:
        rem = target_words - len(text.split())
        if passes > 0:
            msgs.append({
                "role": "user",
                "content": f"Continue the text; roughly {rem} more words needed. Avoid repetition."
            })
        # Rough words→tokens budget; floor keeps max_new_tokens sane when only
        # a handful of words remain.
        chunk = _one_pass(msgs, max_tokens=max(int(rem * 1.6), 16))
        # Fix: feed the chunk back as an assistant turn so later "continue"
        # requests actually show the model what it has written so far.
        msgs.append({"role": "assistant", "content": chunk})
        text += ("" if not text else " ") + chunk
        passes += 1
    return textwrap.fill(text.strip(), 90)
# Inline stylesheet injected into the book's <head>; consumed by WeasyPrint
# when rendering the PDF (page margins, serif body text, justified paragraphs).
CSS_STYLE = """
<style>
body { font-family: 'Georgia', serif; font-size: 12pt; line-height: 1.6; margin: 2.5cm; }
h1 { font-size: 24pt; margin: 0 0 0.5em; page-break-after: avoid; }
p { margin: 0 0 0.8em; text-align: justify; }
</style>
"""
def generate_book(theme: str) -> str:
    """Generate every blueprint section, assemble HTML, and render a PDF.

    Parameters
    ----------
    theme : str
        User-supplied book theme.

    Returns
    -------
    str
        Path to the rendered PDF inside a fresh temp directory.

    Raises
    ------
    RuntimeError
        If WeasyPrint failed to import at startup.
    """
    if not WEASYPRINT_OK:
        raise RuntimeError("WeasyPrint import failed. Ensure Cairo & Pango are installed.")
    from html import escape  # stdlib; local import keeps module imports untouched

    # NOTE(review): the temp directory is never cleaned up — acceptable for a
    # Space with ephemeral storage, but worth revisiting for long-lived hosts.
    tmp = Path(tempfile.mkdtemp(prefix="book_"))
    parts = []
    for sec in BOOK_SECTION_BLUEPRINT:
        print(f"Generating {sec['title']} ({sec['words']} words)…")
        content = _generate_section(sec['title'], theme, sec['words'])
        # Fix: escape model output so stray '<' or '&' in generated text
        # cannot corrupt the HTML document fed to WeasyPrint.
        parts.append(f"<h1>{escape(sec['title'])}</h1><p>{escape(content)}</p>")
    doc = (
        "<html><head><meta charset='UTF-8'>" + CSS_STYLE + "</head><body>"
        + "<div style='page-break-after:always'></div>".join(parts)
        + "</body></html>"
    )
    html_path = tmp / "book.html"
    pdf_path = tmp / f"book_{uuid.uuid4()}.pdf"
    html_path.write_text(doc, encoding="utf-8")
    HTML(str(html_path)).write_pdf(str(pdf_path))
    return str(pdf_path)
# ---- Gradio UI -------------------------------------------------------------
# Component creation order inside the `with` blocks determines on-page layout.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🪄 Obsidian Chat & 300-Page Book Maker")
    with gr.Tab("Chat"):
        # Chat tab: ChatInterface calls chat_fn(message, history) per turn.
        gr.ChatInterface(fn=chat_fn, title=None)
    with gr.Tab("Book"):
        topic = gr.Textbox(label="Book theme", placeholder="e.g. A steampunk adventure")
        run_btn = gr.Button("Generate 300+ page book")
        status = gr.Textbox(label="Progress", interactive=False)
        pdf_out = gr.File(label="Download your PDF")

        def run(theme):
            # Generator callback: first yield shows progress immediately,
            # the second delivers either the finished PDF or the error text.
            yield gr.update(value="⏳ Generating – please wait…"), None
            try:
                pdf = generate_book(theme)
                yield gr.update(value="✅ Done! Click to download."), pdf
            except Exception as e:
                # Surface the failure in the UI; full traceback goes to stderr.
                traceback.print_exc(file=sys.stderr)
                yield gr.update(value=f"⚠️ Failed: {e}"), None

        run_btn.click(run, inputs=topic, outputs=[status, pdf_out])

if __name__ == "__main__":
    demo.launch()