# Provenance (Hugging Face file-viewer header captured together with the source):
# "Update app.py" by olumideola, commit 5915f8c (verified), 6.31 kB.
"""mist-qg-1.5b demo — passage in, questions out, 25 languages.
Run locally with: python app.py
On HF Spaces: this file + requirements.txt + README.md (with the
Spaces YAML header) is the whole deployment.
"""
import json
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repository id of the question-generation model served by this demo.
MODEL_ID = "olaverse/mist-qg-1.5b"

# Language code -> display name. The code is what the UI passes to generate();
# the display name is what gets substituted into the prompt.
LANG_NAMES = {
    "eng": "English",
    "fra": "French",
    "deu": "German",
    "spa": "Spanish",
    "por": "Portuguese",
    "ita": "Italian",
    "nld": "Dutch",
    "rus": "Russian",
    "pol": "Polish",
    "tur": "Turkish",
    "vie": "Vietnamese",
    "ind": "Indonesian",
    "hin": "Hindi",
    "jpn": "Japanese",
    "kor": "Korean",
    "yor": "Yoruba",
    "ibo": "Igbo",
    "hau": "Hausa",
    "swh": "Swahili",
    "amh": "Amharic",
    "zul": "Zulu",
    "xho": "Xhosa",
    "sna": "Shona",
    "som": "Somali",
    "afr": "Afrikaans",
}

# Languages flagged on the model card as lower-confidence. They stay selectable;
# picking one only adds a caption to the status message.
WEAK_LANGS = {"amh", "sna", "som"}

# System message sent with every request.
TEACHER_SYSTEM = "You write search-style questions that a passage directly answers."

# User-message template. Placeholders: {n}, {language}, {slots}, {passage}.
# The doubled braces on the JSON line are str.format escapes for literal braces.
TEACHER_TEMPLATE = "\n".join(
    (
        "You are given a passage. Write {n} questions that the passage directly answers.",
        "Rules:",
        "- Every question MUST be answerable using ONLY this passage.",
        "- NEVER copy or repeat a sentence from the passage.",
        "- Rewrite the information into a natural question.",
        "- Questions should sound like something a real person would ask in a search engine.",
        "- Do not quote the passage.",
        "- Vary the question types: factual, yes/no, why/how, comparison.",
        "- Write all questions in {language}.",
        "- Return ONLY valid JSON:",
        '{{"questions": [{slots}]}}',
        "Passage: {passage}",
    )
)

# Rows for the "Try an example" widget: [passage, language code, question count].
EXAMPLES = [
    [
        "Tides are caused by the gravitational pull of the moon and, to a lesser extent, the sun, acting on Earth's oceans.",
        "eng",
        3,
    ],
    [
        "Les marées sont causées par l'attraction gravitationnelle de la lune et, dans une moindre mesure, du soleil, agissant sur les océans de la Terre.",
        "fra",
        3,
    ],
    [
        "Photosynthesis converts sunlight into chemical energy in plants, using carbon dioxide and water to produce glucose and oxygen.",
        "eng",
        3,
    ],
    [
        "Ìjì máa ń wáyé nítorí agbára ìfàmọ́ra òṣùpá àti oòrùn lórí omi inú òkun ayé.",
        "yor",
        3,
    ],
]

# Process-wide cache for the loaded model and tokenizer; both start unset.
_model = None
_tok = None
def _load():
    """Return the cached ``(model, tokenizer)`` pair, creating it on first use.

    The first call loads the tokenizer and the causal LM for ``MODEL_ID`` and
    stores them in the module globals ``_model`` / ``_tok``; later calls hand
    back those same objects. When CUDA is available the model is loaded in
    bfloat16 with ``device_map="auto"``; otherwise in float32 with no device
    map.
    """
    global _model, _tok

    # Already loaded: nothing to do.
    if _model is not None:
        return _model, _tok

    _tok = AutoTokenizer.from_pretrained(MODEL_ID)

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        weight_dtype, placement = torch.bfloat16, "auto"
    else:
        weight_dtype, placement = torch.float32, None

    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=weight_dtype,
        device_map=placement,
    )
    # Put the model in evaluation mode; this app only runs generation.
    _model.eval()
    return _model, _tok
def _parse_questions(text: str) -> list:
    """Pull the non-empty question strings out of the model's raw reply.

    The reply is expected to contain a JSON object shaped like
    ``{"questions": ["...", ...]}``, possibly surrounded by other text.
    Parsing starts at the first ``{`` and stops at the end of that JSON
    value, so trailing chatter (even if it contains braces) is ignored.

    Returns an empty list when no JSON object can be decoded, when the
    ``questions`` entry is missing, or when it is not a list. Non-string and
    blank entries are dropped; kept entries are whitespace-stripped.
    """
    start = text.find("{")
    if start == -1:
        return []
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return []
    if not isinstance(obj, dict):
        return []
    raw = obj.get("questions")
    # A string here would otherwise be iterated character by character, and
    # null / a number would raise TypeError, so insist on a real list.
    if not isinstance(raw, list):
        return []
    return [q.strip() for q in raw if isinstance(q, str) and q.strip()]


def generate(passage: str, language_code: str, n: int):
    """Generate ``n`` questions about ``passage`` in the chosen language.

    Args:
        passage: Source text. ``None`` is treated as empty; surrounding
            whitespace is stripped before use.
        language_code: Key into ``LANG_NAMES``. An unknown code is passed to
            the prompt unchanged as the language name.
        n: Number of questions to request; coerced with ``int()``.

    Returns:
        A ``(text, status)`` pair of strings. On success ``text`` is a
        numbered list of questions, one per line. If the passage is empty or
        shorter than 20 characters, ``text`` is empty and ``status`` holds a
        warning; no model call is made. If the model reply cannot be parsed,
        ``text`` is the raw reply and ``status`` says so. For codes in
        ``WEAK_LANGS`` a lower-confidence caption is appended to ``status``.
    """
    passage = (passage or "").strip()
    if not passage:
        return "", "⚠️ Enter a passage first."
    if len(passage) < 20:
        return "", "⚠️ Passage is very short — results may be poor."

    count = int(n)
    model, tok = _load()
    lang_name = LANG_NAMES.get(language_code, language_code)

    # One '"..."' placeholder per requested question shows the model the
    # exact JSON shape it should fill in.
    user = TEACHER_TEMPLATE.format(
        n=count,
        language=lang_name,
        slots=", ".join(['"..."'] * count),
        passage=passage,
    )
    messages = [
        {"role": "system", "content": TEACHER_SYSTEM},
        {"role": "user", "content": user},
    ]

    # Build model inputs and move every tensor to the model's device.
    inputs = tok.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=250,
            do_sample=False,
            pad_token_id=pad_id,
        )

    # The output sequence starts with the prompt tokens; decode only what
    # comes after them.
    prompt_len = inputs["input_ids"].shape[1]
    text = tok.decode(out[0][prompt_len:], skip_special_tokens=True)

    questions = _parse_questions(text)

    note = ""
    if language_code in WEAK_LANGS:
        note = (
            f"\n\n⚠️ {lang_name} is one of this model's lower-confidence languages "
            f"(see the [model card](https://huggingface.co/{MODEL_ID}) for benchmark numbers)."
        )

    if not questions:
        return (
            text,
            f"⚠️ Couldn't parse valid questions from the model's output. Raw output shown alongside.{note}",
        )

    formatted = "\n".join(
        f"{i}. {q}" for i, q in enumerate(questions, start=1)
    )
    return formatted, "✅ Generated." + note
# UI definition. Inputs go in the left column, results in the right one.
# NOTE(review): nesting was reconstructed from statement order because the
# captured source had lost its indentation — confirm that the examples widget
# and the click wiring belong at the Blocks level, below the row.
with gr.Blocks(title="mist-qg-1.5b — multilingual question generator") as demo:
    gr.Markdown(
        value=(
            "# mist-qg-1.5b\n"
            "Passage in, search-style questions out — across 25 languages. "
            f"[Model card](https://huggingface.co/{MODEL_ID}) · "
            "[Training data](https://huggingface.co/datasets/olaverse/qg-passages-multi)"
        )
    )

    with gr.Row():
        with gr.Column():
            passage_in = gr.Textbox(
                label="Passage",
                lines=6,
                placeholder="Paste a paragraph of text here...",
            )
            # Dropdown entries are (display name, language code) pairs,
            # ordered alphabetically by display name.
            lang_in = gr.Dropdown(
                choices=sorted(
                    ((name, code) for code, name in LANG_NAMES.items()),
                    key=lambda pair: pair[0],
                ),
                value="eng",
                label="Language",
            )
            n_in = gr.Slider(
                minimum=1,
                maximum=5,
                value=3,
                step=1,
                label="Number of questions",
            )
            btn = gr.Button(value="Generate Questions", variant="primary")

        with gr.Column():
            questions_out = gr.Textbox(
                label="Generated questions",
                lines=8,
                interactive=False,
            )
            status_out = gr.Markdown()

    gr.Examples(
        examples=EXAMPLES,
        inputs=[passage_in, lang_in, n_in],
        label="Try an example",
    )

    # generate() returns (questions text, status markdown), in that order.
    btn.click(
        fn=generate,
        inputs=[passage_in, lang_in, n_in],
        outputs=[questions_out, status_out],
    )


if __name__ == "__main__":
    demo.launch()