# Arabic Poetry Lab – Streamlit app for Hugging Face Spaces.
import os
import tempfile

# Route Streamlit's writable state to a temp directory and disable the
# file watcher / telemetry *before* streamlit is imported, so the
# settings are picked up and permission errors on HF Spaces are avoided.
_TMP_DIR = tempfile.gettempdir()
os.environ.update({
    "STREAMLIT_SERVER_FILE_WATCHER_TYPE": "none",
    "STREAMLIT_BROWSER_GATHER_USAGE_STATS": "false",
    "STREAMLIT_THEME_BASE": "light",
    "HOME": _TMP_DIR,
})

import json
import uuid
from typing import Any, Dict, List

import streamlit as st
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    pipeline,
)
st.set_page_config(
    layout="wide",
    page_icon="🕊️",
    page_title="Arabic Poetry Lab – Meters, Diacritization & Generation",
)

# -----------------------------
# Model Registry (edit safely)
# -----------------------------
# Each entry maps a display name to the transformers task, HF repo ID,
# source paper, and usage notes shown in the picker UI.
MODEL_REGISTRY = {
    # === Meter classification models ===
    "AraPoemBERT (meter)": {
        "task": "text-classification",
        "repo": "faisalq/bert-base-arapoembert",
        "paper": "AraPoemBERT (Qarah, 2024)",
        "notes": "BERT-based poetry LM, SOTA on meter/sub-meter/rhyme tasks."
    },
    "AraGPT2 (base, Arabic)": {
        "task": "text-generation",
        "repo": "aubmindlab/aragpt2-base",
        "paper": "Antoun et al. (AraGPT2)",
        "notes": "Use with prompts that include meter/rhyme hints."
    },
}

# Markdown shown in the sidebar describing the app's capabilities.
HELP_TEXT = """
### What this Space does
This app lets you **try Arabic poetry models** from the literature:
- **Meter classification** (text) – predict the baḥr class.
- **Era / Theme classification** (text) – Ashaar suite classifiers.
- **Diacritization** – undiacritized → diacritized verse (seq2seq).
- **Poetry generation** – prompt a model to continue a verse with target meter / rhyme / theme hints.
> 🔧 **Tip**: For any entry with an empty model repo, paste the exact Hugging Face repo ID (e.g., `faisalq/AraPoemBERT-meter`). You can add your own models too.
"""
# -----------------------------
# Caching model pipelines
# -----------------------------
# NOTE: the original version claimed caching ("Cached models for faster
# subsequent runs") but had no cache decorator, so every button press
# re-downloaded/re-loaded the model. st.cache_resource fixes that by
# keeping one pipeline per (task, model_id) across reruns and sessions.
@st.cache_resource
def get_pipeline(task: str, model_id: str):
    """Load (and cache) a transformers pipeline with free-tier optimizations.

    Args:
        task: One of "text-classification", "text2text-generation",
            "text-generation", or "fill-mask".
        model_id: Hugging Face repo ID, used for both model and tokenizer.

    Returns:
        A `transformers.pipeline` placed on GPU (device 0) when CUDA is
        available, otherwise CPU (device -1).

    Raises:
        ValueError: If *task* is not one of the supported tasks (surfaced
            via st.error, then re-raised, like any other load failure).
    """
    try:
        # Use GPU if available, but don't require it.
        device = 0 if torch.cuda.is_available() else -1
        # Per-task extras beyond the common model/tokenizer/device kwargs.
        extra = {
            "text-classification": {"top_k": None},
            "text2text-generation": {},
            # Generation models: half precision on GPU, low RAM on CPU.
            "text-generation": {
                "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
                "low_cpu_mem_usage": True,
            },
            "fill-mask": {},
        }.get(task)
        if extra is None:
            raise ValueError(f"Unsupported task: {task}")
        return pipeline(
            task,
            model=model_id,
            tokenizer=model_id,
            device=device,
            **extra,
        )
    except Exception as e:
        # Surface the error in the UI, then propagate to the caller.
        st.error(f"Error loading model: {str(e)}")
        raise
def section_header(title, emoji="✨"):
    """Render a level-2 markdown heading prefixed with *emoji*."""
    heading = f"## {emoji} {title}"
    st.markdown(heading)
def model_picker(task_filter: str, context: str = "") -> Dict[str, Any]:
    """Render widgets to pick a registered model (or paste a custom repo).

    Args:
        task_filter: transformers task name used to filter MODEL_REGISTRY.
        context: Short tag (e.g. "meter", "era") distinguishing call sites
            so each tab gets its own widget keys.

    Returns:
        Dict with keys: name, task, repo, paper, notes. "repo" comes from
        the editable text input, so users can override the registry value.
    """
    subset = {k: v for k, v in MODEL_REGISTRY.items() if v["task"] == task_filter}
    names = list(subset.keys())
    # BUG FIX: widget keys must be STABLE across Streamlit reruns. The
    # previous uuid4-based suffix minted a brand-new key on every rerun,
    # which discarded the widget's state (selection / typed repo ID) on
    # each interaction. context + task is unique per call site and stable.
    unique_suffix = f"{context}_{task_filter}"
    if not names:
        st.warning(f"No models registered for task: {task_filter}")
        st.info("You can add a custom model repo ID below.")
        repo = st.text_input(
            "Model repo on Hugging Face",
            placeholder="org/model-id",
            key=f"repo_custom_{unique_suffix}"
        )
        return {
            "name": "Custom",
            "task": task_filter,
            "repo": repo,
            "paper": "N/A",
            "notes": "Custom model"
        }
    choice = st.selectbox(
        "Pick a model",
        names,
        key=f"picker_{unique_suffix}"
    )
    cfg = subset[choice]
    repo = st.text_input(
        "Model repo on Hugging Face",
        value=cfg["repo"],
        placeholder="org/model-id",
        key=f"repo_{unique_suffix}"
    )
    st.caption(f"**Paper**: {cfg['paper']} \n**Notes**: {cfg['notes']}")
    return {
        "name": choice,
        "task": cfg["task"],
        "repo": repo,
        "paper": cfg["paper"],
        "notes": cfg["notes"]
    }
# -----------------------------
# Sidebar
# -----------------------------
with st.sidebar:
    st.title("Arabic Poetry Lab")
    st.info("Plug your model repo IDs, then run 🔽")
    st.markdown(HELP_TEXT)
    st.markdown("---")
    st.markdown("**Quick admin**")
    # Module-level flag read by every results section below.
    show_raw = st.checkbox("Show raw HF output", value=False)
    st.caption("Raw = full JSON from transformers pipeline")

st.title("🕊️ Arabic Poetry Lab on HF")
st.write("Try meter classifiers, diacritizers, and generators from the literature.")

_TAB_LABELS = [
    "Meter classification",
    "Era / Theme classification",
    "Diacritization",
    "Poetry generation",
    "Instructions",
]
tabs = st.tabs(_TAB_LABELS)
# -----------------------------
# Tab 1: Meter classification
# -----------------------------
with tabs[0]:
    section_header("Meter classification (text)", "📏")
    cfg = model_picker("text-classification", context="meter")
    verse = st.text_area(
        "Paste a single bayt (verse) or hemistich",
        height=120,
        placeholder="اكتب البيت هنا ...",
        key="meter_verse",
    )
    topk = st.slider("Top-k labels to show", 1, 16, 5, key="meter_topk")
    if st.button("Classify meter", type="primary", key="classify_meter"):
        if not cfg.get("repo") or not verse.strip():
            st.warning("Please provide both a model repo and input text.")
        else:
            with st.spinner("Loading model and classifying..."):
                try:
                    clf = get_pipeline(cfg["task"], cfg["repo"])
                    preds = clf(verse)
                    # Normalize pipeline output to a flat list of
                    # {"label", "score"} dicts regardless of nesting.
                    if isinstance(preds, list) and preds:
                        results = preds[0] if isinstance(preds[0], list) else preds
                    elif isinstance(preds, dict):
                        results = [preds]
                    else:
                        results = []
                    ranked = sorted(results, key=lambda p: p.get("score", 0), reverse=True)
                    st.subheader("Predictions")
                    for pred in ranked[:topk]:
                        st.write(f"**{pred.get('label','?')}** — {pred.get('score', 0):.4f}")
                    if show_raw:
                        st.json(preds)
                except Exception as e:
                    st.error(f"Error: {str(e)}")
# -----------------------------
# Tab 2: Era / Theme classification
# -----------------------------
def _classify_and_render(model_cfg, input_text, k, spinner_msg, heading):
    """Run a text-classification pipeline and render its top-k predictions.

    Shared by the Era and Theme buttons below, which were identical
    except for the model config and the headings.
    """
    with st.spinner(spinner_msg):
        try:
            clf = get_pipeline(model_cfg["task"], model_cfg["repo"])
            preds = clf(input_text)
            # Flatten the nested-list output shape; wrap a bare dict.
            if isinstance(preds, list) and len(preds) > 0:
                if isinstance(preds[0], list):
                    preds = preds[0]
            else:
                preds = [preds] if isinstance(preds, dict) else []
            preds = sorted(preds, key=lambda x: x.get("score", 0), reverse=True)[:k]
            st.subheader(heading)
            for r in preds:
                st.write(f"**{r.get('label','?')}** — {r.get('score', 0):.4f}")
            # NOTE: unlike tab 1, the raw view here shows the sorted,
            # truncated predictions (matches the original behavior).
            if show_raw:
                st.json(preds)
        except Exception as e:
            st.error(f"Error: {str(e)}")


with tabs[1]:
    section_header("Era / Theme classification", "🗂️")
    st.info("Add models for era/theme classification by pasting their repo IDs below.")
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**Era**")
        cfg_era = model_picker("text-classification", context="era")
    with col2:
        st.markdown("**Theme**")
        cfg_theme = model_picker("text-classification", context="theme")
    text = st.text_area(
        "Paste verse(s) for classification",
        height=150,
        placeholder="اكتب الأبيات هنا ...",
        key="era_theme_text",
    )
    topk_et = st.slider("Top-k labels", 1, 10, 5, key="topk_et")
    col_btn1, col_btn2 = st.columns(2)
    with col_btn1:
        run_era = st.button("Classify Era", key="btn_era")
    with col_btn2:
        run_theme = st.button("Classify Theme", key="btn_theme")
    if run_era:
        if not cfg_era.get("repo") or not text.strip():
            st.warning("Please provide both a model repo and input text.")
        else:
            _classify_and_render(cfg_era, text, topk_et, "Classifying era...", "Era predictions")
    if run_theme:
        if not cfg_theme.get("repo") or not text.strip():
            st.warning("Please provide both a model repo and input text.")
        else:
            _classify_and_render(cfg_theme, text, topk_et, "Classifying theme...", "Theme predictions")
# -----------------------------
# Tab 3: Diacritization
# -----------------------------
with tabs[2]:
    section_header("Diacritization (seq2seq)", "🕊️")
    cfg_diac = model_picker("text2text-generation", context="diac")
    src = st.text_area(
        "Undiacritized verse(s)",
        height=150,
        placeholder="اكتب النص بدون تشكيل ...",
        key="diac_src",
    )
    max_new = st.slider("Max tokens", 16, 256, 96, key="diac_max")
    num_beams = st.slider("Beams", 1, 6, 4, key="diac_beams")
    if st.button("Diacritize", type="primary", key="btn_diac"):
        if not cfg_diac.get("repo") or not src.strip():
            st.warning("Please provide both a model repo and input text.")
        else:
            with st.spinner("Diacritizing..."):
                try:
                    seq2seq = get_pipeline(cfg_diac["task"], cfg_diac["repo"])
                    out = seq2seq(src, max_new_tokens=max_new, num_beams=num_beams)
                    st.subheader("Output")
                    if isinstance(out, list) and out:
                        first = out[0]
                        # Pipelines name their text field differently;
                        # prefer the common keys, else take the first key.
                        if "generated_text" in first:
                            text_key = "generated_text"
                        elif "summary_text" in first:
                            text_key = "summary_text"
                        else:
                            text_key = next(iter(first))
                        st.write(first[text_key])
                    if show_raw:
                        st.json(out)
                except Exception as e:
                    st.error(f"Error: {str(e)}")
# -----------------------------
# Tab 4: Poetry generation
# -----------------------------
with tabs[3]:
    section_header("Poetry generation", "📝")
    cfg_gen = model_picker("text-generation", context="gen")
    prompt = st.text_area(
        "Prompt (include hints: meter / qafiyah / theme)",
        height=150,
        placeholder="مثال: [meter=الطويل, qafiyah=م, theme=غزل]\nيا دارَ مَيّة بالعلياءِ فالسندِ ...",
        key="gen_prompt",
    )
    max_new = st.slider("Max new tokens", 16, 256, 80, key="gen_max_new")
    temp = st.slider("Temperature", 0.1, 1.5, 0.9, 0.1, key="gen_temp")
    top_p = st.slider("top_p", 0.1, 1.0, 0.92, 0.01, key="gen_top_p")
    top_k = st.slider("top_k", 0, 100, 50, key="gen_top_k")
    do_sample = st.checkbox("do_sample", value=True, key="gen_sample")
    if st.button("Generate", type="primary", key="btn_gen"):
        if not cfg_gen.get("repo") or not prompt.strip():
            st.warning("Please provide both a model repo and a prompt.")
        else:
            with st.spinner("Generating poetry..."):
                try:
                    generator = get_pipeline(cfg_gen["task"], cfg_gen["repo"])
                    # GPT-style tokenizers often lack a pad token; fall
                    # back to EOS. Explicit `is None` check so a valid
                    # pad id of 0 is not mistaken for "missing".
                    pad_token_id = generator.tokenizer.pad_token_id
                    if pad_token_id is None:
                        pad_token_id = generator.tokenizer.eos_token_id
                    out = generator(
                        prompt,
                        max_new_tokens=max_new,
                        do_sample=do_sample,
                        temperature=float(temp),
                        top_p=float(top_p),
                        top_k=int(top_k),
                        pad_token_id=pad_token_id,
                    )
                    st.subheader("Generated verse(s)")
                    if isinstance(out, list) and out:
                        st.write(out[0].get("generated_text", ""))
                    if show_raw:
                        st.json(out)
                except Exception as e:
                    st.error(f"Error: {str(e)}")
# -----------------------------
# Tab 5: Instructions
# -----------------------------
# Static help content; kept in a named constant so the tab body stays small.
_INSTRUCTIONS_MD = """
### What each model does
**Meter classification**
- Input: A verse (bayt) or hemistich.
- Output: The most likely **baḥr** (meter) label(s) with scores.
- Recommended models:
- *AraPoemBERT (meter)* — from **Qarah (2024)**.
- *MetRec GRU* — *Al-Shaibani et al.* (14 meters).
- *APCD2 BiLSTM* — *Abandah et al.* (16 meters + prose).
**Era / Theme classification (Ashaar)**
- Input: Verse(s).
- Output: Era (e.g., pre-Islamic, Abbasid…) or Theme (e.g., ghazal, fakhr, heja…).
- Recommended: *Ashaar – Era / Theme classifier*.
**Diacritization**
- Input: Undiacritized verse(s).
- Output: Diacritized text.
- Recommended: *Ashaar – Diacritizer* (text2text-generation / seq2seq).
**Poetry generation**
- Input: Prompt with optional hints: `[meter=..., qafiyah=..., theme=...]` then a seed line.
- Output: Continuation in similar style (try adjusting temperature/top-p).
- Recommended: *Ashaar – Character GPT* (conditional), *AraGPT2 (base)*, *GPT-J (base)*.
> ⚠️ **Note on model repos**
> If a dropdown shows an empty repo, paste the exact Hugging Face ID of the model you want to try (e.g., `faisalq/AraPoemBERT-meter`, `ARBML/ashaar-diacritizer`).
> This keeps the app flexible as you curate your preferred checkpoints.
---
### Tips
- For **generation**, lower `temperature` and `top_p` for stricter meter adherence if your checkpoint supports it; increase for more creative output.
- For **classification**, use single lines (or consistent lines) per run for best results.
- If a model is large (e.g., GPT-J), use smaller `max_new_tokens` or consider upgrading to a GPU space.
- On free tier, models load on CPU. First run may be slow as models download and cache.
### Free Tier Optimizations
- Models use CPU by default (GPU if available)
- Smaller precision (float16) used when GPU is available
- `low_cpu_mem_usage=True` for generation models
- Cached models for faster subsequent runs
"""

with tabs[4]:
    section_header("How to use each model", "📘")
    st.markdown(_INSTRUCTIONS_MD)