NourFakih committed on
Commit
e9c6a59
·
verified ·
1 Parent(s): 3cf5411

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +250 -190
src/streamlit_app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py
2
  import os
3
  import json
4
  from typing import Dict, Any, List
@@ -7,9 +6,9 @@ import streamlit as st
7
  from transformers import (
8
  AutoTokenizer, AutoModelForSequenceClassification,
9
  AutoModelForCausalLM, AutoModelForSeq2SeqLM,
10
- TextClassificationPipeline, TextGenerationPipeline, pipeline
11
  )
12
- from functools import lru_cache
13
 
14
  st.set_page_config(
15
  page_title="Arabic Poetry Lab – Meters, Diacritization & Generation",
@@ -20,102 +19,29 @@ st.set_page_config(
20
  # -----------------------------
21
  # Model Registry (edit safely)
22
  # -----------------------------
23
- # Put the exact model repo IDs you want to try here.
24
- # If you're not sure yet, leave as-is; the app will prompt the user to paste custom IDs.
25
  MODEL_REGISTRY = {
26
  # === Meter classification models ===
27
  "AraPoemBERT (meter)": {
28
  "task": "text-classification",
29
- "repo": "faisalq/bert-base-arapoembert", # e.g. "faisalq/AraPoemBERT-meter"
30
  "paper": "AraPoemBERT (Qarah, 2024)",
31
  "notes": "BERT-based poetry LM, SOTA on meter/sub-meter/rhyme tasks."
32
  },
33
- # "MetRec GRU (text meter classifier)": {
34
- # "task": "text-classification",
35
- # "repo": "", # e.g. "arbml/metrec-gru-meter-classifier"
36
- # "paper": "Al-Shaibani et al. (MetRec)",
37
- # "notes": "5-layer GRU; 14 meters; trained on MetRec (55.4k verses)."
38
- # },
39
- # "APCD2 BiLSTM (meter + prose)": {
40
- # "task": "text-classification",
41
- # "repo": "", # e.g. "abandah/apcd2-bilstm-17classes"
42
- # "paper": "Abandah et al. (APCD2)",
43
- # "notes": "Deep BiLSTM; 16 meters + prose (17 classes)."
44
- # },
45
-
46
- # # === Era / theme classifiers (Ashaar suite) ===
47
- # "Ashaar – Meter classifier": {
48
- # "task": "text-classification",
49
- # "repo": "", # e.g. "ARBML/ashaar-meter-classifier"
50
- # "paper": "Ashaar (Alyafeai, Al-Shaibani, Ahmed)",
51
- # "notes": "Character-level or BERT-based meter classifier."
52
- # },
53
- # "Ashaar – Era classifier": {
54
- # "task": "text-classification",
55
- # "repo": "", # e.g. "ARBML/ashaar-era-classifier"
56
- # "paper": "Ashaar (Alyafeai, Al-Shaibani, Ahmed)",
57
- # "notes": "Predicts poem era (e.g., pre-Islamic, Abbasid, etc.)."
58
- # },
59
- # "Ashaar – Theme classifier": {
60
- # "task": "text-classification",
61
- # "repo": "", # e.g. "ARBML/ashaar-theme-classifier"
62
- # "paper": "Ashaar (Alyafeai, Al-Shaibani, Ahmed)",
63
- # "notes": "Predicts poem theme (e.g., ghazal, fakhr, heja...)."
64
- # },
65
-
66
- # # === Diacritization (Ashaar diacritizer or any seq2seq) ===
67
- # "Ashaar – Diacritizer": {
68
- # "task": "text2text-generation",
69
- # "repo": "", # e.g. "ARBML/ashaar-diacritizer"
70
- # "paper": "Ashaar (Alyafeai, Al-Shaibani, Ahmed)",
71
- # "notes": "Takes undiacritized verse → diacritized verse."
72
- # },
73
-
74
- # # === Poetry generation ===
75
- # "Ashaar – Character GPT (conditional)": {
76
- # "task": "text-generation",
77
- # "repo": "", # e.g. "ARBML/ashaar-char-gpt"
78
- # "paper": "Ashaar (Alyafeai, Al-Shaibani, Ahmed)",
79
- # "notes": "Condition on meter/qafiyah/theme in the prompt."
80
- # },
81
  "AraGPT2 (base, Arabic)": {
82
  "task": "text-generation",
83
  "repo": "aubmindlab/aragpt2-base",
84
  "paper": "Antoun et al. (AraGPT2)",
85
  "notes": "Use with prompts that include meter/rhyme hints."
86
  },
87
- # "GPT-J 6B (base)": {
88
- # "task": "text-generation",
89
- # "repo": "EleutherAI/gpt-j-6B",
90
- # "paper": "EleutherAI GPT-J",
91
- # "notes": "Heavy model; enable low VRAM settings if needed."
92
- # },
93
-
94
- # # === Baselines / Classical Arabic encoders (for zero-shot fun) ===
95
- # "CAMeLBERT-CA (baseline encoder)": {
96
- # "task": "fill-mask", # let users try zero/few-shot gimmicks
97
- # "repo": "CAMeL-Lab/bert-base-arabic-camelbert-ca",
98
- # "paper": "Inoue et al. (CAMeLBERT-CA)",
99
- # "notes": "Good for Classical Arabic; not a meter classifier."
100
- # },
101
- # "AraBERTv1 (baseline encoder)": {
102
- # "task": "fill-mask",
103
- # "repo": "aubmindlab/bert-base-arabertv01",
104
- # "paper": "Antoun et al. (AraBERT)",
105
- # "notes": "Modern Standard Arabic baseline."
106
- # },
107
  }
108
 
109
  HELP_TEXT = """
110
  ### What this Space does
111
-
112
  This app lets you **try Arabic poetry models** from the literature:
113
-
114
  - **Meter classification** (text) – predict the baḥr class.
115
  - **Era / Theme classification** (text) – Ashaar suite classifiers.
116
  - **Diacritization** – undiacritized → diacritized verse (seq2seq).
117
  - **Poetry generation** – prompt a model to continue a verse with target meter / rhyme / theme hints.
118
-
119
  > 🔧 **Tip**: For any entry with an empty model repo, paste the exact Hugging Face repo ID (e.g., `faisalq/AraPoemBERT-meter`). You can add your own models too.
120
  """
121
 
@@ -124,35 +50,79 @@ This app lets you **try Arabic poetry models** from the literature:
124
  # -----------------------------
125
  @st.cache_resource(show_spinner=False)
126
  def get_pipeline(task: str, model_id: str):
127
- # Pick a default pipeline based on 'task'
128
- if task == "text-classification":
129
- return pipeline("text-classification", model=model_id, tokenizer=model_id, top_k=None)
130
- elif task == "text2text-generation":
131
- return pipeline("text2text-generation", model=model_id, tokenizer=model_id)
132
- elif task == "text-generation":
133
- tok = AutoTokenizer.from_pretrained(model_id)
134
- mdl = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
135
- return TextGenerationPipeline(model=mdl, tokenizer=tok)
136
- elif task == "fill-mask":
137
- return pipeline("fill-mask", model=model_id, tokenizer=model_id)
138
- else:
139
- raise ValueError(f"Unsupported task: {task}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  def section_header(title, emoji="✨"):
142
  st.markdown(f"## {emoji} {title}")
143
 
144
  def model_picker(task_filter: str) -> Dict[str, Any]:
 
145
  subset = {k: v for k, v in MODEL_REGISTRY.items() if v["task"] == task_filter}
146
  names = list(subset.keys())
 
147
  if not names:
148
- st.warning("No models registered for this task.")
149
- return {}
150
-
151
- choice = st.selectbox("Pick a model", names)
 
 
152
  cfg = subset[choice]
153
- repo = st.text_input("Model repo on Hugging Face", value=cfg["repo"], placeholder="org/model-id")
 
 
 
 
 
154
  st.caption(f"**Paper**: {cfg['paper']} \n**Notes**: {cfg['notes']}")
155
- return {"name": choice, "task": cfg["task"], "repo": repo, "paper": cfg["paper"], "notes": cfg["notes"]}
 
 
 
 
 
 
156
 
157
  # -----------------------------
158
  # Sidebar
@@ -183,72 +153,125 @@ tabs = st.tabs([
183
  with tabs[0]:
184
  section_header("Meter classification (text)", "📏")
185
  cfg = model_picker("text-classification")
186
- verse = st.text_area("Paste a single bayt (verse) or hemistich", height=120, placeholder="اكتب البيت هنا ...")
187
- topk = st.slider("Top-k labels to show", 1, 16, 5)
188
-
189
- if st.button("Classify meter", type="primary", disabled=not (cfg and cfg.get("repo") and verse.strip())):
190
- try:
191
- clf = get_pipeline(cfg["task"], cfg["repo"])
192
- preds = clf(verse)
193
- # Handle both list of dicts or single dict returned
194
- results = preds if isinstance(preds, list) else [preds]
195
-
196
- # Some pipelines return already sorted; ensure top-k
197
- results_sorted = sorted(results, key=lambda x: x.get("score", 0), reverse=True)[:topk]
198
- st.subheader("Predictions")
199
- for r in results_sorted:
200
- st.write(f"**{r.get('label','?')}** {r.get('score', 0):.4f}")
201
- if show_raw:
202
- st.json(preds)
203
- except Exception as e:
204
- st.error(f"Error loading/running model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  # -----------------------------
207
  # Tab 2: Era / Theme classification
208
  # -----------------------------
209
  with tabs[1]:
210
  section_header("Era / Theme classification", "🗂️")
 
 
211
  col1, col2 = st.columns(2)
212
  with col1:
213
  st.markdown("**Era**")
214
- cfg_era = {**model_picker("text-classification"), "kind": "era"}
215
  with col2:
216
  st.markdown("**Theme**")
217
- cfg_theme = {**model_picker("text-classification"), "kind": "theme"}
218
 
219
- text = st.text_area("Paste verse(s) for classification", height=150, placeholder="اكتب الأبيات هنا ...")
 
 
 
 
 
220
  topk_et = st.slider("Top-k labels", 1, 10, 5, key="topk_et")
221
 
222
- run_era = st.button("Classify Era", disabled=not (cfg_era and cfg_era.get("repo") and text.strip()))
223
- run_theme = st.button("Classify Theme", disabled=not (cfg_theme and cfg_theme.get("repo") and text.strip()))
 
 
 
224
 
225
  if run_era:
226
- try:
227
- p = get_pipeline(cfg_era["task"], cfg_era["repo"])
228
- preds = p(text)
229
- preds = preds if isinstance(preds, list) else [preds]
230
- preds = sorted(preds, key=lambda x: x.get("score", 0), reverse=True)[:topk_et]
231
- st.subheader("Era predictions")
232
- for r in preds:
233
- st.write(f"**{r.get('label','?')}** — {r.get('score', 0):.4f}")
234
- if show_raw:
235
- st.json(preds)
236
- except Exception as e:
237
- st.error(f"Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  if run_theme:
240
- try:
241
- p = get_pipeline(cfg_theme["task"], cfg_theme["repo"])
242
- preds = p(text)
243
- preds = preds if isinstance(preds, list) else [preds]
244
- preds = sorted(preds, key=lambda x: x.get("score", 0), reverse=True)[:topk_et]
245
- st.subheader("Theme predictions")
246
- for r in preds:
247
- st.write(f"**{r.get('label','?')}** — {r.get('score', 0):.4f}")
248
- if show_raw:
249
- st.json(preds)
250
- except Exception as e:
251
- st.error(f"Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
  # -----------------------------
254
  # Tab 3: Diacritization
@@ -256,21 +279,37 @@ with tabs[1]:
256
  with tabs[2]:
257
  section_header("Diacritization (seq2seq)", "🕊️")
258
  cfg_diac = model_picker("text2text-generation")
259
- src = st.text_area("Undiacritized verse(s)", height=150, placeholder="اكتب النص بدون تشكيل ...")
260
- max_new = st.slider("Max tokens", 16, 256, 96)
261
- num_beams = st.slider("Beams", 1, 6, 4)
262
- if st.button("Diacritize", type="primary", disabled=not (cfg_diac and cfg_diac.get("repo") and src.strip())):
263
- try:
264
- p = get_pipeline(cfg_diac["task"], cfg_diac["repo"])
265
- out = p(src, max_new_tokens=max_new, num_beams=num_beams)
266
- st.subheader("Output")
267
- # Typical format: [{"generated_text": "..."}] or [{"summary_text": "..."}]
268
- text_key = "generated_text" if "generated_text" in out[0] else list(out[0].keys())[0]
269
- st.write(out[0][text_key])
270
- if show_raw:
271
- st.json(out)
272
- except Exception as e:
273
- st.error(f"Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
  # -----------------------------
276
  # Tab 4: Poetry generation
@@ -281,33 +320,47 @@ with tabs[3]:
281
  prompt = st.text_area(
282
  "Prompt (include hints: meter / qafiyah / theme)",
283
  height=150,
284
- placeholder="مثال: [meter=الطويل, qafiyah=م, theme=غزل]\nيا دارَ مَيّة بالعلياءِ فالسندِ ..."
 
285
  )
286
- max_new = st.slider("Max new tokens", 16, 256, 80, key="mx_new_gen")
287
- temp = st.slider("Temperature", 0.1, 1.5, 0.9, 0.1)
288
- top_p = st.slider("top_p", 0.1, 1.0, 0.92, 0.01)
289
- top_k = st.slider("top_k", 0, 100, 50)
290
- do_sample = st.checkbox("do_sample", value=True)
291
-
292
- if st.button("Generate", type="primary", disabled=not (cfg_gen and cfg_gen.get("repo") and prompt.strip())):
293
- try:
294
- p = get_pipeline(cfg_gen["task"], cfg_gen["repo"])
295
- out = p(
296
- prompt,
297
- max_new_tokens=max_new,
298
- do_sample=do_sample,
299
- temperature=float(temp),
300
- top_p=float(top_p),
301
- top_k=int(top_k),
302
- pad_token_id=getattr(p.tokenizer, "eos_token_id", None)
303
- )
304
- st.subheader("Generated verse(s)")
305
- txt = out[0].get("generated_text", "")
306
- st.write(txt)
307
- if show_raw:
308
- st.json(out)
309
- except Exception as e:
310
- st.error(f"Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  # -----------------------------
313
  # Tab 5: Instructions
@@ -343,11 +396,18 @@ with tabs[4]:
343
  > ⚠️ **Note on model repos**
344
  > If a dropdown shows an empty repo, paste the exact Hugging Face ID of the model you want to try (e.g., `faisalq/AraPoemBERT-meter`, `ARBML/ashaar-diacritizer`).
345
  > This keeps the app flexible as you curate your preferred checkpoints.
346
- """)
347
- st.markdown("---")
348
- st.markdown("""
349
  ### Tips
350
  - For **generation**, lower `temperature` and `top_p` for stricter meter adherence if your checkpoint supports it; increase for more creative output.
351
  - For **classification**, use single lines (or consistent lines) per run for best results.
352
- - If a model is large (e.g., GPT-J), use smaller `max_new_tokens` or host the Space with a GPU.
353
- """)
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  from typing import Dict, Any, List
 
6
  from transformers import (
7
  AutoTokenizer, AutoModelForSequenceClassification,
8
  AutoModelForCausalLM, AutoModelForSeq2SeqLM,
9
+ pipeline
10
  )
11
+ import torch
12
 
13
  st.set_page_config(
14
  page_title="Arabic Poetry Lab – Meters, Diacritization & Generation",
 
19
  # -----------------------------
20
  # Model Registry (edit safely)
21
  # -----------------------------
 
 
22
  MODEL_REGISTRY = {
23
  # === Meter classification models ===
24
  "AraPoemBERT (meter)": {
25
  "task": "text-classification",
26
+ "repo": "faisalq/bert-base-arapoembert",
27
  "paper": "AraPoemBERT (Qarah, 2024)",
28
  "notes": "BERT-based poetry LM, SOTA on meter/sub-meter/rhyme tasks."
29
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  "AraGPT2 (base, Arabic)": {
31
  "task": "text-generation",
32
  "repo": "aubmindlab/aragpt2-base",
33
  "paper": "Antoun et al. (AraGPT2)",
34
  "notes": "Use with prompts that include meter/rhyme hints."
35
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  }
37
 
38
  HELP_TEXT = """
39
  ### What this Space does
 
40
  This app lets you **try Arabic poetry models** from the literature:
 
41
  - **Meter classification** (text) – predict the baḥr class.
42
  - **Era / Theme classification** (text) – Ashaar suite classifiers.
43
  - **Diacritization** – undiacritized → diacritized verse (seq2seq).
44
  - **Poetry generation** – prompt a model to continue a verse with target meter / rhyme / theme hints.
 
45
  > 🔧 **Tip**: For any entry with an empty model repo, paste the exact Hugging Face repo ID (e.g., `faisalq/AraPoemBERT-meter`). You can add your own models too.
46
  """
47
 
 
50
  # -----------------------------
51
@st.cache_resource(show_spinner=False)
def get_pipeline(task: str, model_id: str):
    """Build and cache a transformers pipeline for *task* from *model_id*.

    Cached via st.cache_resource, so each (task, model_id) pair is loaded at
    most once per server process. Uses the GPU (device 0) when CUDA is
    available, otherwise the CPU (device -1) — free-tier Spaces have no GPU.

    Raises:
        ValueError: if *task* is not one of the supported pipeline tasks.
        Exception: model-loading failures are surfaced via st.error and
            re-raised so callers can handle them.
    """
    try:
        # Prefer the GPU when present, but never require one.
        device = 0 if torch.cuda.is_available() else -1

        if task == "text-classification":
            # top_k=None returns scores for every label, not just the best.
            return pipeline(
                "text-classification",
                model=model_id,
                tokenizer=model_id,
                device=device,
                top_k=None,
            )
        if task == "text2text-generation":
            return pipeline(
                "text2text-generation",
                model=model_id,
                tokenizer=model_id,
                device=device,
            )
        if task == "text-generation":
            # Half precision on GPU to cut memory. low_cpu_mem_usage is a
            # from_pretrained() option and must be routed through
            # model_kwargs — passed directly to pipeline() it would be
            # forwarded to generate() and fail at generation time.
            return pipeline(
                "text-generation",
                model=model_id,
                tokenizer=model_id,
                device=device,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                model_kwargs={"low_cpu_mem_usage": True},
            )
        if task == "fill-mask":
            return pipeline(
                "fill-mask",
                model=model_id,
                tokenizer=model_id,
                device=device,
            )
        raise ValueError(f"Unsupported task: {task}")
    except Exception as e:
        # Surface the failure in the UI, then re-raise for the caller.
        st.error(f"Error loading model: {str(e)}")
        raise
95
 
96
def section_header(title, emoji="✨"):
    """Render *title* as a level-2 markdown heading prefixed with *emoji*."""
    heading = f"## {emoji} {title}"
    st.markdown(heading)
98
 
99
def model_picker(task_filter: str) -> Dict[str, Any]:
    """Render a model-selection widget for models registered under *task_filter*.

    Returns a dict with keys: name, task, repo, paper, notes. If no model is
    registered for the task, falls back to a free-form repo-ID input.
    """
    # Streamlit widget keys must be unique across the whole page. This picker
    # is rendered several times per run with the same task_filter (once in the
    # meter tab and twice in the era/theme tab), so keys derived only from
    # task_filter/len(names) collide and crash with a duplicate-widget-key
    # error. Track the occurrence index per task instead. The counter lives on
    # the function object, which is recreated on every script rerun (Streamlit
    # re-executes the whole file), so keys stay stable across reruns as long
    # as the call order is stable.
    counts = getattr(model_picker, "_key_counts", None)
    if counts is None:
        counts = {}
        model_picker._key_counts = counts
    slot = counts.get(task_filter, 0)
    counts[task_filter] = slot + 1
    key_base = f"{task_filter}_{slot}"

    subset = {k: v for k, v in MODEL_REGISTRY.items() if v["task"] == task_filter}
    names = list(subset.keys())

    if not names:
        st.warning(f"No models registered for task: {task_filter}")
        st.info("You can add a custom model repo ID below.")
        repo = st.text_input(
            "Model repo on Hugging Face",
            placeholder="org/model-id",
            key=f"custom_repo_{key_base}",  # keyed so two empty pickers don't collide
        )
        return {"name": "Custom", "task": task_filter, "repo": repo, "paper": "N/A", "notes": "Custom model"}

    choice = st.selectbox("Pick a model", names, key=f"picker_{key_base}")
    cfg = subset[choice]
    repo = st.text_input(
        "Model repo on Hugging Face",
        value=cfg["repo"],
        placeholder="org/model-id",
        key=f"repo_{key_base}",  # keyed per occurrence, not per choice
    )
    st.caption(f"**Paper**: {cfg['paper']} \n**Notes**: {cfg['notes']}")
    return {
        "name": choice,
        "task": cfg["task"],
        "repo": repo,
        "paper": cfg["paper"],
        "notes": cfg["notes"],
    }
126
 
127
  # -----------------------------
128
  # Sidebar
 
153
with tabs[0]:
    section_header("Meter classification (text)", "📏")
    cfg = model_picker("text-classification")
    verse = st.text_area(
        "Paste a single bayt (verse) or hemistich",
        height=120,
        placeholder="اكتب البيت هنا ...",
        key="meter_verse",
    )
    topk = st.slider("Top-k labels to show", 1, 16, 5, key="meter_topk")

    if st.button("Classify meter", type="primary", key="classify_meter"):
        if not cfg.get("repo") or not verse.strip():
            st.warning("Please provide both a model repo and input text.")
        else:
            with st.spinner("Loading model and classifying..."):
                try:
                    clf = get_pipeline(cfg["task"], cfg["repo"])
                    preds = clf(verse)

                    # Normalise the pipeline output to a flat list of dicts:
                    # it may be [{...}, ...], [[{...}, ...]] (one input with
                    # all-label scores), or a bare dict.
                    if isinstance(preds, list):
                        results = preds[0] if preds and isinstance(preds[0], list) else preds
                    elif isinstance(preds, dict):
                        results = [preds]
                    else:
                        results = []

                    ranked = sorted(results, key=lambda r: r.get("score", 0), reverse=True)
                    st.subheader("Predictions")
                    for r in ranked[:topk]:
                        st.write(f"**{r.get('label','?')}** — {r.get('score', 0):.4f}")

                    if show_raw:
                        st.json(preds)
                except Exception as e:
                    st.error(f"Error: {str(e)}")
194
 
195
  # -----------------------------
196
  # Tab 2: Era / Theme classification
197
  # -----------------------------
198
with tabs[1]:
    section_header("Era / Theme classification", "🗂️")
    st.info("Add models for era/theme classification by pasting their repo IDs below.")

    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**Era**")
        cfg_era = model_picker("text-classification")
    with col2:
        st.markdown("**Theme**")
        cfg_theme = model_picker("text-classification")

    text = st.text_area(
        "Paste verse(s) for classification",
        height=150,
        placeholder="اكتب الأبيات هنا ...",
        key="era_theme_text",
    )
    topk_et = st.slider("Top-k labels", 1, 10, 5, key="topk_et")

    col_btn1, col_btn2 = st.columns(2)
    with col_btn1:
        run_era = st.button("Classify Era", key="btn_era")
    with col_btn2:
        run_theme = st.button("Classify Theme", key="btn_theme")

    def _run_classifier(cfg_sel, spinner_msg, heading):
        # Shared driver for both buttons: validate inputs, run the pipeline,
        # normalise its output and render the top-k predictions.
        if not cfg_sel.get("repo") or not text.strip():
            st.warning("Please provide both a model repo and input text.")
            return
        with st.spinner(spinner_msg):
            try:
                clf = get_pipeline(cfg_sel["task"], cfg_sel["repo"])
                preds = clf(text)

                # Flatten [[{...}, ...]] / [{...}, ...] / {...} to a list.
                if isinstance(preds, list):
                    if preds and isinstance(preds[0], list):
                        preds = preds[0]
                else:
                    preds = [preds] if isinstance(preds, dict) else []

                preds = sorted(preds, key=lambda r: r.get("score", 0), reverse=True)[:topk_et]

                st.subheader(heading)
                for r in preds:
                    st.write(f"**{r.get('label','?')}** — {r.get('score', 0):.4f}")

                if show_raw:
                    st.json(preds)
            except Exception as e:
                st.error(f"Error: {str(e)}")

    if run_era:
        _run_classifier(cfg_era, "Classifying era...", "Era predictions")

    if run_theme:
        _run_classifier(cfg_theme, "Classifying theme...", "Theme predictions")
275
 
276
  # -----------------------------
277
  # Tab 3: Diacritization
 
279
with tabs[2]:
    section_header("Diacritization (seq2seq)", "🕊️")
    cfg_diac = model_picker("text2text-generation")
    src = st.text_area(
        "Undiacritized verse(s)",
        height=150,
        placeholder="اكتب النص بدون تشكيل ...",
        key="diac_src",
    )
    max_new = st.slider("Max tokens", 16, 256, 96, key="diac_max")
    num_beams = st.slider("Beams", 1, 6, 4, key="diac_beams")

    if st.button("Diacritize", type="primary", key="btn_diac"):
        if not cfg_diac.get("repo") or not src.strip():
            st.warning("Please provide both a model repo and input text.")
        else:
            with st.spinner("Diacritizing..."):
                try:
                    seq2seq = get_pipeline(cfg_diac["task"], cfg_diac["repo"])
                    out = seq2seq(src, max_new_tokens=max_new, num_beams=num_beams)

                    st.subheader("Output")
                    # Seq2seq pipelines label their output differently
                    # ("generated_text" vs "summary_text"); fall back to
                    # whichever key comes first if neither is present.
                    if isinstance(out, list) and out:
                        first = out[0]
                        if "generated_text" in first:
                            st.write(first["generated_text"])
                        elif "summary_text" in first:
                            st.write(first["summary_text"])
                        else:
                            st.write(first[list(first.keys())[0]])

                    if show_raw:
                        st.json(out)
                except Exception as e:
                    st.error(f"Error: {str(e)}")
313
 
314
  # -----------------------------
315
  # Tab 4: Poetry generation
 
320
  prompt = st.text_area(
321
  "Prompt (include hints: meter / qafiyah / theme)",
322
  height=150,
323
+ placeholder="مثال: [meter=الطويل, qafiyah=م, theme=غزل]\nيا دارَ مَيّة بالعلياءِ فالسندِ ...",
324
+ key="gen_prompt"
325
  )
326
+ max_new = st.slider("Max new tokens", 16, 256, 80, key="gen_max_new")
327
+ temp = st.slider("Temperature", 0.1, 1.5, 0.9, 0.1, key="gen_temp")
328
+ top_p = st.slider("top_p", 0.1, 1.0, 0.92, 0.01, key="gen_top_p")
329
+ top_k = st.slider("top_k", 0, 100, 50, key="gen_top_k")
330
+ do_sample = st.checkbox("do_sample", value=True, key="gen_sample")
331
+
332
+ if st.button("Generate", type="primary", key="btn_gen"):
333
+ if not cfg_gen.get("repo") or not prompt.strip():
334
+ st.warning("Please provide both a model repo and a prompt.")
335
+ else:
336
+ with st.spinner("Generating poetry..."):
337
+ try:
338
+ p = get_pipeline(cfg_gen["task"], cfg_gen["repo"])
339
+
340
+ # Get pad_token_id safely
341
+ pad_token_id = p.tokenizer.pad_token_id
342
+ if pad_token_id is None:
343
+ pad_token_id = p.tokenizer.eos_token_id
344
+
345
+ out = p(
346
+ prompt,
347
+ max_new_tokens=max_new,
348
+ do_sample=do_sample,
349
+ temperature=float(temp),
350
+ top_p=float(top_p),
351
+ top_k=int(top_k),
352
+ pad_token_id=pad_token_id
353
+ )
354
+
355
+ st.subheader("Generated verse(s)")
356
+ if isinstance(out, list) and len(out) > 0:
357
+ txt = out[0].get("generated_text", "")
358
+ st.write(txt)
359
+
360
+ if show_raw:
361
+ st.json(out)
362
+ except Exception as e:
363
+ st.error(f"Error: {str(e)}")
364
 
365
  # -----------------------------
366
  # Tab 5: Instructions
 
396
  > ⚠️ **Note on model repos**
397
  > If a dropdown shows an empty repo, paste the exact Hugging Face ID of the model you want to try (e.g., `faisalq/AraPoemBERT-meter`, `ARBML/ashaar-diacritizer`).
398
  > This keeps the app flexible as you curate your preferred checkpoints.
399
+
400
+ ---
401
+
402
  ### Tips
403
  - For **generation**, lower `temperature` and `top_p` for stricter meter adherence if your checkpoint supports it; increase for more creative output.
404
  - For **classification**, use single lines (or consistent lines) per run for best results.
405
+ - If a model is large (e.g., GPT-J), use smaller `max_new_tokens` or consider upgrading to a GPU space.
406
+ - On free tier, models load on CPU. First run may be slow as models download and cache.
407
+
408
+ ### Free Tier Optimizations
409
+ - Models use CPU by default (GPU if available)
410
+ - Smaller precision (float16) used when GPU is available
411
+ - `low_cpu_mem_usage=True` for generation models
412
+ - Cached models for faster subsequent runs
413
+ """)