# Scraped page header preserved as a comment (was not valid Python):
# author: drixo — commit 69abda4
# "Add Gradio Space app, push_to_hub, README, fix train/test paths"
"""
Hugging Face Space: Multilingual Document Assistant
Run this as a Gradio app on Hugging Face Spaces.
Set HF_MODEL_ID to your Hub model (e.g. your-username/multilingual-doc-assistant).
"""
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Resolve which model to load, in priority order:
#   1. HF_MODEL_ID env var (set on Spaces under Settings → Variables),
#   2. a locally trained checkpoint folder next to this script,
#   3. a public fallback model.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_LOCAL_MODEL = os.path.join(_SCRIPT_DIR, "multilingual-doc-model")
if os.environ.get("HF_MODEL_ID"):
    HF_MODEL_ID = os.environ["HF_MODEL_ID"]
elif os.path.isdir(_LOCAL_MODEL):
    HF_MODEL_ID = _LOCAL_MODEL
else:
    HF_MODEL_ID = "bigscience/bloom-560m"
def load_pipeline():
    """Build the text-generation pipeline for HF_MODEL_ID.

    Ensures the tokenizer has a pad token (falls back to EOS) and places the
    pipeline on GPU 0 when CUDA is available, CPU otherwise.
    """
    tok = AutoTokenizer.from_pretrained(HF_MODEL_ID)
    lm = AutoModelForCausalLM.from_pretrained(HF_MODEL_ID)
    # Causal LMs often ship without a pad token; generation needs one.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    target_device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        "text-generation",
        model=lm,
        tokenizer=tok,
        device=target_device,
    )
# Build the generation pipeline once at import time so Spaces reuses it
# across requests instead of reloading the model on every call.
pipe = load_pipeline()
def _get_text(content):
"""Extract plain text from Gradio message content (str or list of parts)."""
if isinstance(content, str):
return content
if isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
return part.get("text", "")
if isinstance(part, str):
return part
return ""
def build_prompt(history, message):
    """Serialize chat history plus the new user message into a text prompt.

    Handles both Gradio history formats:
      * tuple/list pairs: ``(user_msg, assistant_msg)``
      * "messages" dicts: ``{"role": ..., "content": ...}`` — user and
        assistant entries arrive as separate items and are re-paired here.

    Bug fix: assistant turns are no longer dropped. In "messages" format an
    assistant dict has no user text of its own, so the old ``if user_msg:``
    guard discarded the entire assistant side of the conversation; tuple
    turns carrying only an assistant reply were dropped for the same reason.
    """
    parts = []
    pending_user = ""  # user text awaiting its assistant reply (messages format)
    for turn in history:
        if isinstance(turn, (list, tuple)) and len(turn) >= 2:
            user_msg = str(turn[0] or "")
            assistant_msg = str(turn[1] or "")
            if user_msg or assistant_msg:
                parts.append(f"User: {user_msg}\nAssistant: {assistant_msg}")
        elif isinstance(turn, dict):
            role = turn.get("role", "")
            content = _get_text(turn.get("content", ""))
            if role == "user":
                if pending_user:
                    # Two user messages in a row: flush the unanswered one.
                    parts.append(f"User: {pending_user}\nAssistant: ")
                pending_user = content
            elif content:
                parts.append(f"User: {pending_user}\nAssistant: {content}")
                pending_user = ""
    if pending_user:
        # Trailing user message with no reply yet.
        parts.append(f"User: {pending_user}\nAssistant: ")
    parts.append(f"User: {message}\nAssistant:")
    return "\n".join(parts)
def chat(message, history):
    """Generate one assistant reply for the Gradio ChatInterface.

    message: the user's new input text.
    history: prior conversation turns (forwarded to build_prompt).
    Returns the model's reply text, or "" for blank input.
    """
    # Ignore empty / whitespace-only submissions.
    if not message.strip():
        return ""
    prompt = build_prompt(history, message)
    generated = pipe(
        prompt,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.7,
        pad_token_id=pipe.tokenizer.pad_token_id,
    )[0]["generated_text"]
    # Keep only the text after the final "Assistant:" marker; if the model
    # somehow produced no marker, strip the prompt prefix instead.
    marker = "Assistant:"
    if marker in generated:
        reply = generated.rsplit(marker, 1)[1].strip()
    else:
        reply = generated[len(prompt):].strip()
    # Truncate anything the model hallucinated as a follow-up user turn.
    for stop in ("\nUser:", "\n\nUser:"):
        head, sep, _tail = reply.partition(stop)
        if sep:
            reply = head.strip()
    return reply
# Assemble the Space UI: a titled page with a chat widget and a model footer.
with gr.Blocks(
    title="Multilingual Document Assistant",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown("""
# Multilingual Document Assistant
**Supports:** Spanish · Chinese · Vietnamese · Portuguese
Ask about documents, get explanations, or chat. *(Agent-style responses)*
""")
    # FIX: dropped retry_btn / undo_btn / clear_btn. Those keyword arguments
    # were deprecated in Gradio 4.x and removed in Gradio 5 (the current
    # Spaces runtime), where passing them raises TypeError; ChatInterface
    # now renders retry/undo/clear controls automatically.
    gr.ChatInterface(
        fn=chat,
        type="messages",
        examples=[
            ["Explícame este documento: La IA mejora la productividad."],
            ["总结这段文字: 人工智能正在改变世界。"],
            ["Giải thích đoạn này: Công nghệ giúp cuộc sống dễ dàng hơn."],
        ],
    )
    gr.Markdown(f"*Model: `{HF_MODEL_ID}`*")

if __name__ == "__main__":
    demo.launch()