| | import torch |
| | import gradio as gr |
| | from transformers import BartTokenizer, BartForConditionalGeneration |
| | from transformers.modeling_outputs import BaseModelOutput |
| |
|
| | |
| | |
| | |
| | |
# Checkpoint used for both encoding and decoding. bart-large is the base
# (non-finetuned) seq2seq model, so decoding encoder states tends to
# reconstruct the input text rather than summarize/paraphrase it.
model_name = "facebook/bart-large"

print(f"Loading {model_name}...")
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
model.eval()  # inference mode: disables dropout; pair with torch.no_grad() at call sites
| |
|
| | |
| | |
| | |
| |
|
def text_to_embedding(text):
    """Run *text* through the BART encoder and return its hidden states.

    Returns a tensor of shape (1, seq_len, hidden_dim): one embedding
    vector per input token, special tokens included.
    """
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        hidden_states = model.model.encoder(**encoded).last_hidden_state
    return hidden_states
| |
|
def embedding_to_text(embedding_tensor):
    """Greedy-decode text from a tensor of encoder hidden states.

    Parameters
    ----------
    embedding_tensor : torch.Tensor
        Encoder output of shape (batch, seq_len, hidden_dim), e.g. as
        produced by ``text_to_embedding`` (possibly blended/interpolated).

    Returns
    -------
    str
        The decoded string for the first batch element.
    """
    # generate() accepts pre-computed encoder states when wrapped in a
    # BaseModelOutput, skipping the encoder forward pass entirely.
    encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)

    with torch.no_grad():
        generated_ids = model.generate(
            encoder_outputs=encoder_outputs_wrapped,
            max_length=50,
            # Greedy search (num_beams=1, do_sample=False) forces a direct
            # transcription of the vector rather than creative generation.
            # Fix: the sampling-only knobs temperature/repetition_penalty
            # were dropped — they are ignored under greedy decoding, and
            # recent transformers versions warn when temperature is set
            # while do_sample=False. Output is unchanged.
            num_beams=1,
            do_sample=False,
        )

    decoded_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return decoded_text
| |
|
def run_weighted_mixing(text1, text2, mix_ratio):
    """Linearly blend the embeddings of two sentences and decode the result.

    A mix_ratio of 0.0 keeps text1's embedding, 1.0 keeps text2's;
    intermediate values interpolate between them.
    """
    # Guard clause: both sentences are required.
    if not text1 or not text2:
        return "Enter sentences."

    emb_start = text_to_embedding(text1)
    emb_end = text_to_embedding(text2)

    # The sentences may tokenize to different lengths; truncate both to the
    # shorter sequence so they can be blended element-wise.
    shared_len = min(emb_start.shape[1], emb_end.shape[1])
    emb_start = emb_start[:, :shared_len, :]
    emb_end = emb_end[:, :shared_len, :]

    # Convex combination: (1 - r) * start + r * end.
    blended = (emb_start * (1 - mix_ratio)) + (emb_end * mix_ratio)
    return embedding_to_text(blended)
| |
|
| | |
| | |
| | |
| |
|
with gr.Blocks(title="BART-Large Vector Decoder", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦜 BART Strict Vector Decoder")
    gr.Markdown("This version uses `bart-large` with **Greedy Search** to force direct transcription instead of creative generation.")

    # Two sentences whose encoder embeddings will be linearly blended.
    with gr.Row():
        start_box = gr.Textbox(label="Start Sentence", value="The dog is happy.")
        end_box = gr.Textbox(label="End Sentence", value="The cat is angry.")

    # 0.0 decodes the start sentence's vector, 1.0 the end sentence's.
    ratio_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Ratio (0 = Start, 1 = End)")

    decode_btn = gr.Button("Decode Vector")
    output_box = gr.Textbox(label="Decoded Text")

    decode_btn.click(run_weighted_mixing, inputs=[start_box, end_box, ratio_slider], outputs=output_box)
| |
|
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script (not on import).
    demo.launch()