# Test-Arabic-models / src/streamlit_app.py
# Author: NourFakih
# Last commit: 785a09c "Update src/streamlit_app.py" (verified)
import os
import tempfile
# Point Streamlit's configuration at a writable location before Streamlit is
# imported. HOME is redirected to the system temp dir, which avoids
# permission errors on HF Spaces; the other variables disable the file
# watcher and usage-stat collection and force the light theme.
temp_dir = tempfile.gettempdir()
os.environ.update({
    "STREAMLIT_SERVER_FILE_WATCHER_TYPE": "none",
    "STREAMLIT_BROWSER_GATHER_USAGE_STATS": "false",
    "STREAMLIT_THEME_BASE": "light",
    "HOME": temp_dir,
})
import json
from typing import Dict, Any, List
import uuid
import streamlit as st
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification,
AutoModelForCausalLM, AutoModelForSeq2SeqLM,
pipeline
)
import torch
# Page-level metadata (browser tab title, favicon, wide layout), set before
# any other page element is rendered.
PAGE_CONFIG = dict(
    page_title="Arabic Poetry Lab – Meters, Diacritization & Generation",
    page_icon="🕊️",
    layout="wide",
)
st.set_page_config(**PAGE_CONFIG)
# -----------------------------
# Model Registry (edit safely)
# -----------------------------
# Display name -> model metadata.
#   task  : transformers pipeline task; model_picker() lists an entry only in
#           tabs that ask for this task.
#   repo  : Hugging Face repo ID used to pre-fill the editable repo box.
#   paper : citation shown in the caption under the picker.
#   notes : free-text hint shown in the same caption.
MODEL_REGISTRY = {
    # === Meter classification models ===
    # NOTE(review): the repo name looks like a base (pre-trained) checkpoint;
    # confirm it ships a fine-tuned meter classification head before trusting
    # its labels.
    "AraPoemBERT (meter)": dict(
        task="text-classification",
        repo="faisalq/bert-base-arapoembert",
        paper="AraPoemBERT (Qarah, 2024)",
        notes="BERT-based poetry LM, SOTA on meter/sub-meter/rhyme tasks.",
    ),
    # === Generation models ===
    "AraGPT2 (base, Arabic)": dict(
        task="text-generation",
        repo="aubmindlab/aragpt2-base",
        paper="Antoun et al. (AraGPT2)",
        notes="Use with prompts that include meter/rhyme hints.",
    ),
}
# Markdown rendered in the sidebar: a short tour of what each tab does plus a
# tip about supplying Hugging Face repo IDs. Built line by line; the leading
# and trailing "" entries reproduce the surrounding newlines.
HELP_TEXT = "\n".join([
    "",
    "### What this Space does",
    "This app lets you **try Arabic poetry models** from the literature:",
    "- **Meter classification** (text) – predict the baḥr class.",
    "- **Era / Theme classification** (text) – Ashaar suite classifiers.",
    "- **Diacritization** – undiacritized → diacritized verse (seq2seq).",
    "- **Poetry generation** – prompt a model to continue a verse with target meter / rhyme / theme hints.",
    "> 🔧 **Tip**: For any entry with an empty model repo, paste the exact Hugging Face repo ID (e.g., `faisalq/AraPoemBERT-meter`). You can add your own models too.",
    "",
])
# -----------------------------
# Caching model pipelines
# -----------------------------
@st.cache_resource(show_spinner=False)
def get_pipeline(task: str, model_id: str):
    """Build a transformers pipeline for ``task`` from ``model_id`` and cache it.

    ``st.cache_resource`` keeps one pipeline per ``(task, model_id)`` pair, so
    a model is downloaded and loaded once and then reused across reruns. The
    pipeline runs on the first CUDA device when one is available, otherwise
    on CPU.

    Args:
        task: One of "text-classification", "text2text-generation",
            "text-generation" or "fill-mask".
        model_id: Hugging Face repo ID used for both the model and tokenizer.

    Returns:
        The constructed ``transformers`` pipeline object.

    Raises:
        ValueError: If ``task`` is not one of the supported tasks.
        Exception: Any error raised while loading is shown with ``st.error``
            and then re-raised for the caller to handle.
    """
    supported_tasks = (
        "text-classification",
        "text2text-generation",
        "text-generation",
        "fill-mask",
    )
    try:
        if task not in supported_tasks:
            raise ValueError(f"Unsupported task: {task}")

        use_cuda = torch.cuda.is_available()
        # transformers' device convention: 0 = first GPU, -1 = CPU.
        device = 0 if use_cuda else -1

        task_kwargs: Dict[str, Any] = {}
        if task == "text-classification":
            # Return a score for every label, not only the best one.
            task_kwargs["top_k"] = None
        elif task == "text-generation":
            # Half precision only pays off (and is only safe) on GPU.
            task_kwargs["torch_dtype"] = torch.float16 if use_cuda else torch.float32
            # Model-loading options must go through ``model_kwargs``.
            # Plain extra kwargs given to ``pipeline()`` are forwarded to the
            # pipeline class, and for text generation they end up as
            # ``generate()`` arguments rather than ``from_pretrained()``
            # arguments.
            task_kwargs["model_kwargs"] = {"low_cpu_mem_usage": True}

        return pipeline(
            task,
            model=model_id,
            tokenizer=model_id,
            device=device,
            **task_kwargs,
        )
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise
def section_header(title: str, emoji: str = "✨") -> None:
    """Render a level-2 Markdown heading of the form ``<emoji> <title>``."""
    heading = "## {} {}".format(emoji, title)
    st.markdown(heading)
def model_picker(task_filter: str, context: str = "") -> Dict[str, Any]:
    """Render a model selector for one task and return the chosen model config.

    Shows the MODEL_REGISTRY entries whose ``task`` equals ``task_filter`` in
    a selectbox, followed by an editable repo-ID text box pre-filled from the
    selected entry. When no registry entry matches, only a free-text repo box
    is shown.

    Args:
        task_filter: Pipeline task name used to filter the registry.
        context: Caller-chosen tag (e.g. "meter", "era") that keeps widget
            keys distinct when the same task is picked in several places.
            Callers rendered on the same page must pass different values,
            otherwise Streamlit rejects the duplicate widget keys.

    Returns:
        Dict with ``name``, ``task``, ``repo`` (the repo ID as currently
        typed by the user), ``paper`` and ``notes``.
    """
    subset = {k: v for k, v in MODEL_REGISTRY.items() if v["task"] == task_filter}
    names = list(subset.keys())

    # Widget keys must be deterministic. Streamlit ties a widget's state to
    # its key, so a key that changes on every rerun (e.g. a random UUID)
    # recreates the widget with its defaults. That would discard the user's
    # selection and typed repo ID on the very rerun triggered by a Run button.
    key_suffix = f"{context}_{task_filter}"

    if not names:
        st.warning(f"No models registered for task: {task_filter}")
        st.info("You can add a custom model repo ID below.")
        repo = st.text_input(
            "Model repo on Hugging Face",
            placeholder="org/model-id",
            key=f"repo_custom_{key_suffix}",
        )
        return {
            "name": "Custom",
            "task": task_filter,
            "repo": repo,
            "paper": "N/A",
            "notes": "Custom model",
        }

    choice = st.selectbox(
        "Pick a model",
        names,
        key=f"picker_{key_suffix}",
    )
    cfg = subset[choice]
    # Including the selected model in the key gives a fresh text box,
    # pre-filled with that model's repo, whenever the selection changes.
    # Without it, the box would keep the previous model's text.
    repo = st.text_input(
        "Model repo on Hugging Face",
        value=cfg["repo"],
        placeholder="org/model-id",
        key=f"repo_{key_suffix}_{choice}",
    )
    # Two trailing spaces before "\n" force a Markdown hard line break.
    st.caption(f"**Paper**: {cfg['paper']}  \n**Notes**: {cfg['notes']}")
    return {
        "name": choice,
        "task": cfg["task"],
        "repo": repo,
        "paper": cfg["paper"],
        "notes": cfg["notes"],
    }
# -----------------------------
# Sidebar
# -----------------------------
# Rendered through the st.sidebar container: title, help text and the global
# "raw output" toggle read by every tab.
st.sidebar.title("Arabic Poetry Lab")
st.sidebar.info("Plug your model repo IDs, then run 🔽")
st.sidebar.markdown(HELP_TEXT)
st.sidebar.markdown("---")
st.sidebar.markdown("**Quick admin**")
show_raw = st.sidebar.checkbox("Show raw HF output", value=False)
st.sidebar.caption("Raw = full JSON from transformers pipeline")

# Main page heading and the five task tabs (indexed 0-4 below).
st.title("🕊️ Arabic Poetry Lab on HF")
st.write("Try meter classifiers, diacritizers, and generators from the literature.")
TAB_LABELS = [
    "Meter classification",
    "Era / Theme classification",
    "Diacritization",
    "Poetry generation",
    "Instructions",
]
tabs = st.tabs(TAB_LABELS)
# -----------------------------
# Tab 1: Meter classification
# -----------------------------
with tabs[0]:
    section_header("Meter classification (text)", "📏")
    meter_cfg = model_picker("text-classification", context="meter")
    verse = st.text_area(
        "Paste a single bayt (verse) or hemistich",
        height=120,
        placeholder="اكتب البيت هنا ...",
        key="meter_verse",
    )
    topk = st.slider("Top-k labels to show", 1, 16, 5, key="meter_topk")

    classify_clicked = st.button("Classify meter", type="primary", key="classify_meter")
    if classify_clicked:
        inputs_ready = bool(meter_cfg.get("repo")) and bool(verse.strip())
        if inputs_ready:
            with st.spinner("Loading model and classifying..."):
                try:
                    classifier = get_pipeline(meter_cfg["task"], meter_cfg["repo"])
                    preds = classifier(verse)
                    # Normalise to a flat list of {label, score} dicts: the
                    # pipeline may return a single dict, a flat list, or a
                    # list wrapping the per-input list.
                    if isinstance(preds, dict):
                        labels = [preds]
                    elif isinstance(preds, list) and preds:
                        labels = preds[0] if isinstance(preds[0], list) else preds
                    else:
                        labels = []
                    ranked = sorted(
                        labels,
                        key=lambda item: item.get("score", 0),
                        reverse=True,
                    )[:topk]
                    st.subheader("Predictions")
                    for item in ranked:
                        st.write(f"**{item.get('label','?')}** — {item.get('score', 0):.4f}")
                    if show_raw:
                        st.json(preds)
                except Exception as e:
                    st.error(f"Error: {str(e)}")
        else:
            st.warning("Please provide both a model repo and input text.")
# -----------------------------
# Tab 2: Era / Theme classification
# -----------------------------
def _classify_and_render(
    model_cfg: Dict[str, Any],
    text: str,
    limit: int,
    spinner_text: str,
    heading: str,
) -> None:
    """Run a text-classification model on ``text`` and render its top labels.

    Args:
        model_cfg: Config dict as returned by ``model_picker`` (uses ``task``
            and ``repo``).
        text: Input verse(s) to classify.
        limit: Maximum number of labels to display, highest score first.
        spinner_text: Message shown while the model loads and runs.
        heading: Subheader placed above the predictions.
    """
    if not model_cfg.get("repo") or not text.strip():
        st.warning("Please provide both a model repo and input text.")
        return
    with st.spinner(spinner_text):
        try:
            classifier = get_pipeline(model_cfg["task"], model_cfg["repo"])
            raw = classifier(text)
            # Normalise to a flat list of {label, score} dicts: the pipeline
            # may return a flat list, a list wrapping the per-input list, or a
            # single dict.
            if isinstance(raw, list) and raw:
                results = raw[0] if isinstance(raw[0], list) else raw
            else:
                results = [raw] if isinstance(raw, dict) else []
            ranked = sorted(results, key=lambda x: x.get("score", 0), reverse=True)[:limit]
            st.subheader(heading)
            for r in ranked:
                st.write(f"**{r.get('label','?')}** — {r.get('score', 0):.4f}")
            if show_raw:
                # Show the untouched pipeline output (not the sorted/truncated
                # list), as the sidebar toggle promises.
                st.json(raw)
        except Exception as e:
            st.error(f"Error: {str(e)}")


with tabs[1]:
    section_header("Era / Theme classification", "🗂️")
    st.info("Add models for era/theme classification by pasting their repo IDs below.")
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**Era**")
        cfg_era = model_picker("text-classification", context="era")
    with col2:
        st.markdown("**Theme**")
        cfg_theme = model_picker("text-classification", context="theme")
    text = st.text_area(
        "Paste verse(s) for classification",
        height=150,
        placeholder="اكتب الأبيات هنا ...",
        key="era_theme_text",
    )
    topk_et = st.slider("Top-k labels", 1, 10, 5, key="topk_et")
    col_btn1, col_btn2 = st.columns(2)
    with col_btn1:
        run_era = st.button("Classify Era", key="btn_era")
    with col_btn2:
        run_theme = st.button("Classify Theme", key="btn_theme")
    if run_era:
        _classify_and_render(cfg_era, text, topk_et, "Classifying era...", "Era predictions")
    if run_theme:
        _classify_and_render(cfg_theme, text, topk_et, "Classifying theme...", "Theme predictions")
# -----------------------------
# Tab 3: Diacritization
# -----------------------------
with tabs[2]:
    section_header("Diacritization (seq2seq)", "🕊️")
    cfg_diac = model_picker("text2text-generation", context="diac")
    src = st.text_area(
        "Undiacritized verse(s)",
        height=150,
        placeholder="اكتب النص بدون تشكيل ...",
        key="diac_src",
    )
    diac_max_new = st.slider("Max tokens", 16, 256, 96, key="diac_max")
    num_beams = st.slider("Beams", 1, 6, 4, key="diac_beams")

    if st.button("Diacritize", type="primary", key="btn_diac"):
        has_inputs = bool(cfg_diac.get("repo")) and bool(src.strip())
        if not has_inputs:
            st.warning("Please provide both a model repo and input text.")
        else:
            with st.spinner("Diacritizing..."):
                try:
                    diacritizer = get_pipeline(cfg_diac["task"], cfg_diac["repo"])
                    diac_out = diacritizer(src, max_new_tokens=diac_max_new, num_beams=num_beams)
                    st.subheader("Output")
                    if isinstance(diac_out, list) and diac_out:
                        first = diac_out[0]
                        # Prefer the usual seq2seq output keys; otherwise fall
                        # back to whatever key the result dict lists first.
                        text_key = next(
                            (k for k in ("generated_text", "summary_text") if k in first),
                            None,
                        )
                        if text_key is None:
                            text_key = list(first.keys())[0]
                        st.write(first[text_key])
                    if show_raw:
                        st.json(diac_out)
                except Exception as e:
                    st.error(f"Error: {str(e)}")
# -----------------------------
# Tab 4: Poetry generation
# -----------------------------
with tabs[3]:
    section_header("Poetry generation", "📝")
    cfg_gen = model_picker("text-generation", context="gen")
    prompt = st.text_area(
        "Prompt (include hints: meter / qafiyah / theme)",
        height=150,
        placeholder="مثال: [meter=الطويل, qafiyah=م, theme=غزل]\nيا دارَ مَيّة بالعلياءِ فالسندِ ...",
        key="gen_prompt",
    )
    max_new = st.slider("Max new tokens", 16, 256, 80, key="gen_max_new")
    temp = st.slider("Temperature", 0.1, 1.5, 0.9, 0.1, key="gen_temp")
    top_p = st.slider("top_p", 0.1, 1.0, 0.92, 0.01, key="gen_top_p")
    top_k = st.slider("top_k", 0, 100, 50, key="gen_top_k")
    do_sample = st.checkbox("do_sample", value=True, key="gen_sample")
    if st.button("Generate", type="primary", key="btn_gen"):
        if not cfg_gen.get("repo") or not prompt.strip():
            st.warning("Please provide both a model repo and a prompt.")
        else:
            with st.spinner("Generating poetry..."):
                try:
                    p = get_pipeline(cfg_gen["task"], cfg_gen["repo"])
                    # Fall back to the EOS token when the tokenizer defines
                    # no pad token.
                    pad_token_id = p.tokenizer.pad_token_id
                    if pad_token_id is None:
                        pad_token_id = p.tokenizer.eos_token_id
                    gen_kwargs = {
                        "max_new_tokens": max_new,
                        "do_sample": do_sample,
                        "pad_token_id": pad_token_id,
                    }
                    # temperature / top_p / top_k only shape sampling. Under
                    # greedy decoding transformers ignores them and warns, so
                    # pass them only when sampling is enabled.
                    if do_sample:
                        gen_kwargs.update(
                            temperature=float(temp),
                            top_p=float(top_p),
                            top_k=int(top_k),
                        )
                    out = p(prompt, **gen_kwargs)
                    st.subheader("Generated verse(s)")
                    if isinstance(out, list) and len(out) > 0:
                        txt = out[0].get("generated_text", "")
                        st.write(txt)
                    if show_raw:
                        st.json(out)
                except Exception as e:
                    st.error(f"Error: {str(e)}")
# -----------------------------
# Tab 5: Instructions
# -----------------------------
# Blank lines between sections are required in Markdown. Without them, each
# bold section title is a lazy continuation of the preceding list item and is
# merged into it. The "Recommended models" entries are indented so they
# render as a nested list.
with tabs[4]:
    section_header("How to use each model", "📘")
    st.markdown("""
### What each model does

**Meter classification**
- Input: A verse (bayt) or hemistich.
- Output: The most likely **baḥr** (meter) label(s) with scores.
- Recommended models:
    - *AraPoemBERT (meter)* — from **Qarah (2024)**.
    - *MetRec GRU* — *Al-Shaibani et al.* (14 meters).
    - *APCD2 BiLSTM* — *Abandah et al.* (16 meters + prose).

**Era / Theme classification (Ashaar)**
- Input: Verse(s).
- Output: Era (e.g., pre-Islamic, Abbasid…) or Theme (e.g., ghazal, fakhr, heja…).
- Recommended: *Ashaar – Era / Theme classifier*.

**Diacritization**
- Input: Undiacritized verse(s).
- Output: Diacritized text.
- Recommended: *Ashaar – Diacritizer* (text2text-generation / seq2seq).

**Poetry generation**
- Input: Prompt with optional hints: `[meter=..., qafiyah=..., theme=...]` then a seed line.
- Output: Continuation in similar style (try adjusting temperature/top-p).
- Recommended: *Ashaar – Character GPT* (conditional), *AraGPT2 (base)*, *GPT-J (base)*.

> ⚠️ **Note on model repos**
>
> If a dropdown shows an empty repo, paste the exact Hugging Face ID of the model you want to try (e.g., `faisalq/AraPoemBERT-meter`, `ARBML/ashaar-diacritizer`).
> This keeps the app flexible as you curate your preferred checkpoints.

---

### Tips
- For **generation**, lower `temperature` and `top_p` for stricter meter adherence if your checkpoint supports it; increase for more creative output.
- For **classification**, use single lines (or consistent lines) per run for best results.
- If a model is large (e.g., GPT-J), use smaller `max_new_tokens` or consider upgrading to a GPU space.
- On free tier, models load on CPU. First run may be slow as models download and cache.

### Free Tier Optimizations
- Models use CPU by default (GPU if available)
- Smaller precision (float16) used when GPU is available
- `low_cpu_mem_usage=True` for generation models
- Cached models for faster subsequent runs
""")