# MTASTE / app.py
# hienbm's picture
# Update app.py
# afb810d verified
# app.py — hienbm/gemma-2-9b-mtaste-16bit (Gradio Space)
import json
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# ── Constants ────────────────────────────────────────────────────────────────
# Hub repository id handed to the tokenizer/model `from_pretrained` calls.
MODEL_ID = "hienbm/gemma-2-9b-mtaste-16bit"

# Closed set of ENTITY#ATTRIBUTE labels the model is allowed to emit.
ASPECT_CATEGORIES = [
    "AMBIENCE#GENERAL",
    "DRINKS#PRICES",
    "DRINKS#QUALITY",
    "DRINKS#STYLE_OPTIONS",
    "FOOD#PRICES",
    "FOOD#QUALITY",
    "FOOD#STYLE_OPTIONS",
    "LOCATION#GENERAL",
    "RESTAURANT#GENERAL",
    "RESTAURANT#MISCELLANEOUS",
    "RESTAURANT#PRICES",
    "SERVICE#GENERAL",
]

# Task prompt placed in front of every review. The pieces below are joined by
# implicit literal concatenation; the label list is interpolated once, here.
INSTRUCTION = (
    # Task statement and reading strategy.
    "Given a restaurant review, extract all sentiment triplets.\n"
    "Read the ENTIRE review first to understand context, sarcasm, and irony.\n"
    "Then extract triplets SENTENCE BY SENTENCE in the ORDER they appear.\n\n"
    # Output schema.
    "Output a JSON array sorted by appearance order in the review:\n"
    '[{"target": <word/phrase or "NULL">, '
    '"aspect": <ASPECT#CATEGORY>, "polarity": <positive|negative|neutral>}]\n\n'
    # Allowed aspect labels.
    f"aspect must be one of: {', '.join(ASPECT_CATEGORIES)}\n\n"
    # Extraction rules.
    "Rules:\n"
    "- Sentence order: extract from sentence 1 first, then sentence 2, etc.\n"
    "- Multiple triplets per sentence: one object per triplet, keep order\n"
    "- target: exact word/phrase from text, or NULL if implicit\n"
    "- Output ONLY the JSON array, no explanation\n\n"
    # One worked example.
    'Example:\nReview: "Food was great. Service was slow."\n'
    'Output: [{"target": "food", "aspect": "FOOD#QUALITY", "polarity": "positive"}, '
    '{"target": "NULL", "aspect": "SERVICE#GENERAL", "polarity": "negative"}]'
)
# Marker shown next to each polarity in the results table. The literals had
# been corrupted by an encoding round-trip (UTF-8 bytes decoded with a legacy
# code page); restored to the large green / red / yellow circle emoji.
POLARITY_EMOJI = {"positive": "🟢", "negative": "🔴", "neutral": "🟡"}
# ── Model loading ────────────────────────────────────────────────────────────
# 4-bit NF4 weights with double quantization; compute dtype is float16.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Loaded once at import time and shared by every request.
# NOTE(review): Gemma-2 checkpoints are commonly run in bfloat16 — confirm that
# float16 is intended for this model/hardware.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=bnb_cfg,
).eval()  # nn.Module.eval() returns the module itself
# ── Prompt & parse ───────────────────────────────────────────────────────────
def build_prompt(text: str) -> str:
    """Wrap *text* in a single-turn chat prompt for the model.

    The user turn carries ``INSTRUCTION`` followed by the review; the prompt
    ends with an opened model turn so generation continues from there.
    """
    user_turn = (
        f"<start_of_turn>user\n{INSTRUCTION}\n\nReview: {text}<end_of_turn>\n"
    )
    model_turn_open = "<start_of_turn>model\n"
    return user_turn + model_turn_open
def parse_json_output(raw: str) -> list[dict]:
    """Extract the first JSON array of objects found in model output.

    Every ``[`` in *raw* is tried in turn as the start of a JSON value. A real
    JSON decoder consumes exactly one value from that position, so brackets
    inside string values (``"pho [large]"``), nested arrays, Markdown code
    fences and surrounding chatter do not confuse the extraction — unlike a
    non-greedy ``\\[.*?\\]`` regex, which stops at the first ``]``.

    Args:
        raw: Decoded model output; may contain text around the JSON.

    Returns:
        The dict entries of the first array that contains at least one dict,
        in their original order. An empty list when *raw* is empty, holds no
        decodable array, or only holds arrays without objects. The result is
        always a list of dicts, so callers can use ``.get`` on every item.
    """
    if not raw:
        return []

    decoder = json.JSONDecoder()
    start = raw.find("[")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(raw, start)
        except json.JSONDecodeError:
            # Not valid JSON from this bracket (e.g. "[note]"); try the next.
            pass
        else:
            objects = [item for item in candidate if isinstance(item, dict)]
            if objects:
                return objects
        start = raw.find("[", start + 1)
    return []
# ── Inference ────────────────────────────────────────────────────────────────
def _md_cell(value: object) -> str:
    """Return *value* as text that is safe inside a one-line Markdown table cell."""
    text = "" if value is None else str(value)
    # A bare pipe would open a new column and a line break would end the row.
    return text.replace("|", "\\|").replace("\r", " ").replace("\n", " ")


def _format_triplets_table(triplets: list[dict], raw: str) -> str:
    """Render parsed triplets as a Markdown table, or a fallback showing *raw*."""
    if not triplets:
        return (
            "No triplets extracted (or JSON parse failed).\n\n**Raw output:**\n```\n"
            + raw
            + "\n```"
        )

    rows = ["| # | Target | Aspect | Polarity |", "|---|--------|--------|----------|"]
    for i, triplet in enumerate(triplets, 1):
        target = _md_cell(triplet.get("target", "NULL") or "NULL")
        aspect = _md_cell(triplet.get("aspect", ""))
        # The model may emit null or a non-string here; normalise before .lower().
        pol = str(triplet.get("polarity") or "").lower()
        emoji = POLARITY_EMOJI.get(pol, "")
        rows.append(f"| {i} | **{target}** | `{aspect}` | {emoji} {_md_cell(pol)} |")
    return "\n".join(rows)


def predict(review_text: str, temperature: float = 0.0) -> tuple[str, str]:
    """Run the model on one review and format the extracted triplets.

    Args:
        review_text: The review typed by the user. Empty, whitespace-only or
            ``None`` input short-circuits without touching the model.
        temperature: ``0`` selects greedy decoding; any positive value enables
            sampling at that temperature.

    Returns:
        ``(raw_json_string, formatted_markdown_table)`` — the decoded model
        output as-is, and a Markdown table of the parsed triplets (or a
        fallback message embedding the raw output when nothing was parsed).
    """
    if not review_text or not review_text.strip():
        return "", "Please enter a review."

    prompt = build_prompt(review_text.strip())
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Only pass `temperature` when sampling: combined with do_sample=False it
    # is ignored and merely triggers a generation-config warning.
    temperature = float(temperature or 0.0)
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": temperature > 0,
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        output = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    raw = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

    # Defensive normalisation: only dict entries of a list can be table rows.
    parsed = parse_json_output(raw)
    triplets = (
        [item for item in parsed if isinstance(item, dict)]
        if isinstance(parsed, list)
        else []
    )
    return raw, _format_triplets_table(triplets, raw)
# ── Gradio UI ────────────────────────────────────────────────────────────────
# Sample inputs offered in the UI. Each entry is a one-element list because the
# examples widget is bound to a single input component (the review textbox).
# The review texts are runtime data and must be kept verbatim.
EXAMPLE_REVIEWS = [
# Multi-paragraph review, mostly positive with one mild criticism.
["""Came in here on a random Thursday after work. There was only one server but she was very efficient and very kind. Place was a bit empty but as I was leaving over 8 people were coming in so I just came at an awkward time.
Inside is absolutely gorgeous. So much seating for individuals, small groups, duos, and even large groups!
I got the Spicy Miso Vegetarian Ramen Noodles with non-vegan noodles. It was absolutely delicious!! The broth was clearly very clean and healthy and it was very comforting. Only thing I think can be improved were the tofu pieces which may need their own marinating process or cooking process to give them more flavor... but i ate them with a spoonful of broth and it tasted great so maybe it's not a problem haha.
would totally recommend and will for sure be coming back with more people."""],
# Multi-paragraph review with mixed sentiment.
["""In all fairness, we came here because of a Groupon and for me that usually indicates the restaurant is trying to push for new customers. We made the drive and glad we did it, but it's not enough to make us return. I wish them success because hibachi is totally fun.
Our chef was excellent! Entertaining enjoyable, all of the things.
That said the restaurant was definitely old and dated but not in the nostalgic way. Food was pretty ok. Pricing is expected but would've been frustrating had it not been for Groupon."""],
# One-sentence review whose praise is contradicted by what follows
# (presumably a sarcasm test case — confirm with the author).
["This place serves fast, it's been over 30 minutes and the dish still hasn't come out."]
]
# Layout: header, then a row with inputs (left) and results (right), then the
# example picker. The button click runs `predict`.
# NOTE(review): nesting was reconstructed from a copy whose indentation had
# been stripped — confirm that Examples belongs below the row, not inside it.
with gr.Blocks(title="MTASTE | Restaurant Review") as demo:
    # Heading text repaired: the emoji and em dash had been mis-decoded.
    gr.Markdown(
        "# 🍽️ MTASTE — Multilingual Target Aspect Sentiment Triplet Extraction at the paragraph level\n"
        "Extract sentiment triplets *(target, aspect, polarity)* from restaurant reviews.\n\n"
        f"**Model:** [`{MODEL_ID}`](https://huggingface.co/{MODEL_ID})"
    )

    with gr.Row():
        # Left column: review text, decoding temperature, submit button.
        with gr.Column(scale=2):
            review_input = gr.Textbox(
                label="Restaurant Review",
                placeholder="Enter a restaurant review here...",
                lines=5,
            )
            temperature = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.1, value=0.0,
                label="Temperature (0 = greedy)",
            )
            submit_btn = gr.Button("Extract Triplets", variant="primary")

        # Right column: rendered table and the raw model output.
        with gr.Column(scale=3):
            table_output = gr.Markdown(label="Extracted Triplets")
            raw_output = gr.Textbox(label="Raw JSON output", lines=4)

    gr.Examples(examples=EXAMPLE_REVIEWS, inputs=review_input)

    # predict() returns (raw, table), so the outputs are listed in that order.
    submit_btn.click(
        fn=predict,
        inputs=[review_input, temperature],
        outputs=[raw_output, table_output],
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()