# metricalT5 / app.py
# bsg25
# Add aligned view tab for rhythm and meter analysis
# e8e90f9
import html

import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Hugging Face Hub checkpoint shared by the tokenizer and the model.
MODEL_ID = "bensglaser/metricalT5-large"

# legacy=False opts into the newer T5 SentencePiece handling in transformers
# (see the T5Tokenizer `legacy` argument docs).
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
# Prompts longer than this many tokens are truncated before encoding.
MAX_INPUT_TOKENS = 512
# Upper bound passed to `model.generate(max_length=...)` for every task.
# NOTE(review): outputs for long inputs may be cut off at this length —
# confirm it suits the longest expected meter/rhythm output.
MAX_OUTPUT_TOKENS = 128


def _run_task(task_prefix, text):
    """Run one MetricalT5 task on *text* and return the decoded output.

    The model is prompted with ``"<task_prefix>: <text>"``; the prefix is the
    only thing that differs between the public task functions below.

    Args:
        task_prefix: Task instruction, e.g. ``"transform to meter"``.
        text: User-supplied input text (``None`` is treated as empty).

    Returns:
        The first generated sequence decoded to a string with special tokens
        removed, or ``""`` when *text* is empty or whitespace-only.
    """
    # An empty textbox would otherwise run a full generation on a bare prefix
    # and return meaningless output.
    if text is None or not text.strip():
        return ""
    input_text = f"{task_prefix}: {text}"
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=MAX_INPUT_TOKENS,
        truncation=True,
    )
    outputs = model.generate(**inputs, max_length=MAX_OUTPUT_TOKENS)
    # Batch size is 1, so only the first sequence is decoded.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def transform_to_meter(text):
    """Return the model's output for the "transform to meter" task on *text*."""
    return _run_task("transform to meter", text)


def transform_to_rhythm(text):
    """Return the model's output for the "transform to rhythm" task on *text*."""
    return _run_task("transform to rhythm", text)


def predict_complexity(text):
    """Return the model's output for the "predict complexity" task on *text*."""
    return _run_task("predict complexity", text)
def run_all_transformations(text):
    """Run every task on *text* and return a single combined report.

    The report has four sections — Rhythm, Original Text, Meter, Complexity —
    each introduced by a Markdown-style bold heading on its own line and
    separated from the next by a blank line.
    """
    # Model calls happen in the order meter, rhythm, complexity.
    meter = transform_to_meter(text)
    rhythm = transform_to_rhythm(text)
    complexity = predict_complexity(text)

    sections = (
        ("Rhythm", rhythm),
        ("Original Text", text),
        ("Meter", meter),
        ("Complexity", complexity),
    )
    return "\n\n".join(f"**{title}:**\n{body}" for title, body in sections)
# Style for the value cells: `white-space: pre` preserves the padding spaces
# that line the rows up — plain HTML would collapse them into one space.
_ALIGNED_CELL_STYLE = "padding: 8px; white-space: pre;"


def _strip_whitespace(sequence):
    """Return *sequence* with every whitespace character (spaces, tabs, newlines) removed."""
    return "".join(sequence.split())


def _segments_for_words(symbols, words):
    """Cut *symbols* into one chunk per word, each padded to that word's length.

    Each word consumes as many symbols as it has characters — a simple
    character-count heuristic, not a syllable alignment. Once *symbols* runs
    out, the remaining chunks are all spaces; symbols left over after the
    last word are dropped.
    """
    segments = []
    start = 0
    for word in words:
        width = len(word)
        segments.append(symbols[start:start + width].ljust(width))
        start += width
    return segments


def _table_row(label, content, *, border):
    """Render one labelled table row; *content* is HTML-escaped."""
    row_style = ' style="border-bottom: 1px solid #ddd;"' if border else ""
    return (
        f"<tr{row_style}>\n"
        f'<td style="padding: 8px;"><strong>{label}:</strong></td>\n'
        f'<td style="{_ALIGNED_CELL_STYLE}">{html.escape(content)}</td>\n'
        "</tr>"
    )


def create_aligned_table(text, meter=None, rhythm=None):
    """Create an aligned HTML table with rhythm, original text, and meter.

    Words of *text* are laid out on one line; above and below each word the
    rhythm and meter symbols it "consumes" are shown, padded to the word's
    width so the three rows line up in a monospace font.

    Args:
        text: User-supplied input text.
        meter: Precomputed "transform to meter" output. When ``None`` (the
            default) the model is run on *text*.
        rhythm: Precomputed "transform to rhythm" output. When ``None`` (the
            default) the model is run on *text*.

    Returns:
        An HTML fragment. All user text and model output are HTML-escaped.
    """
    # Meter first, then rhythm.
    if meter is None:
        meter = transform_to_meter(text)
    if rhythm is None:
        rhythm = transform_to_rhythm(text)

    # Simple word-based splitting for now (not true syllables).
    words = (text or "").split()
    rhythm_line = " ".join(_segments_for_words(_strip_whitespace(rhythm), words))
    text_line = " ".join(words)
    meter_line = " ".join(_segments_for_words(_strip_whitespace(meter), words))

    rows = "\n".join(
        [
            _table_row("Rhythm", rhythm_line, border=True),
            _table_row("Original", text_line, border=True),
            _table_row("Meter", meter_line, border=False),
        ]
    )
    # `white-space: pre` stops long lines from wrapping, so let them scroll.
    return (
        '<div style="overflow-x: auto;">\n'
        '<table style="font-family: monospace; border-collapse: collapse; width: 100%;">\n'
        f"{rows}\n"
        "</table>\n"
        "</div>"
    )
# Create Gradio interface with tabs: one tab per task, plus a combined report
# and an aligned HTML view.
with gr.Blocks() as demo:
    gr.Markdown("# MetricalT5 Model")
    gr.Markdown("Analyze poetic meter, rhythm, and complexity using T5")

    with gr.Tab("Transform to Meter"):
        meter_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter text here...")
        meter_output = gr.Textbox(label="Meter Analysis")
        meter_button = gr.Button("Analyze Meter")
        meter_button.click(transform_to_meter, inputs=meter_input, outputs=meter_output)

    with gr.Tab("Transform to Rhythm"):
        rhythm_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter text here...")
        rhythm_output = gr.Textbox(label="Rhythm Analysis")
        rhythm_button = gr.Button("Analyze Rhythm")
        rhythm_button.click(transform_to_rhythm, inputs=rhythm_input, outputs=rhythm_output)

    with gr.Tab("Predict Complexity"):
        complexity_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter text here...")
        complexity_output = gr.Textbox(label="Complexity Prediction")
        complexity_button = gr.Button("Predict Complexity")
        complexity_button.click(predict_complexity, inputs=complexity_input, outputs=complexity_output)

    with gr.Tab("All Transformations"):
        all_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter text here...")
        all_output = gr.Textbox(label="All Results", lines=10)
        all_button = gr.Button("Run All Transformations")
        all_button.click(run_all_transformations, inputs=all_input, outputs=all_output)

    with gr.Tab("Aligned View"):
        aligned_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter text here...")
        aligned_output = gr.HTML(label="Aligned Table")
        aligned_button = gr.Button("Generate Aligned View")
        aligned_button.click(create_aligned_table, inputs=aligned_input, outputs=aligned_output)

# Only start the server when run as a script (`python app.py`); importing the
# module (e.g. Gradio reload mode or tests) just builds `demo`.
if __name__ == "__main__":
    demo.launch()