File size: 3,640 Bytes
8a9a1e4
 
 
 
 
9126f59
a8bef18
 
9126f59
 
 
a8bef18
9126f59
a8bef18
9126f59
a8bef18
8a9a1e4
9126f59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8bef18
9126f59
 
a8bef18
9126f59
 
 
a8bef18
9126f59
 
 
 
 
 
 
 
a8bef18
9126f59
 
 
 
 
 
 
 
 
 
 
a8bef18
9126f59
 
 
a8bef18
9126f59
 
8a9a1e4
9126f59
 
 
 
 
 
 
a8bef18
9126f59
 
 
 
 
 
 
 
 
 
a8bef18
9126f59
 
 
 
 
 
a8bef18
9126f59
 
 
 
 
 
a8bef18
9126f59
 
 
 
 
 
a8bef18
9126f59
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
from transformers import AutoTokenizer
from peft import AutoPeftModelForSeq2SeqLM
import torch

# Model configuration: Hugging Face Hub repository that holds the tokenizer
# and the LoRA adapter loaded by load_model().
repo_id = "sairika/FLAN-T5-Base-dialogsum-lora"

# Sample dialogues shown as clickable examples in the UI.
# Each conversation is listed one speaker turn per entry and joined with "\n",
# so every example is a single multi-line string.
examples = [
    "\n".join(turns)
    for turns in (
        (
            "Alice: How was your meeting today?",
            "Bob: It went great! The client approved our proposal.",
            "Alice: That's wonderful news!",
            "Bob: Yes, we start next Monday.",
        ),
        (
            "Customer: My internet isn't working.",
            "Support: I can help with that. Have you tried restarting your router?",
            "Customer: Yes, but it's still not working.",
            "Support: Let me check your connection status.",
        ),
        (
            "Teacher: Did you finish your homework?",
            "Student: Almost done, just the math problems left.",
            "Teacher: Need help with anything?",
            "Student: Yes, I'm stuck on question 5.",
        ),
    )
]

def load_model(model_id=None):
    """Load the tokenizer and the LoRA-adapted seq2seq model.

    Loading is best-effort: any failure is reported to the console (message
    plus full traceback) and signalled through the returned flag instead of
    raising, so the UI can still start and explain the problem per request.

    Args:
        model_id: Hub repository id (or local path) to load from. Defaults to
            the module-level ``repo_id`` when omitted or ``None``.

    Returns:
        A ``(tokenizer, model, loaded)`` tuple. On success ``loaded`` is True;
        on failure it is ``(None, None, False)``.
    """
    if model_id is None:
        model_id = repo_id

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # NOTE(review): bfloat16 is applied regardless of which device
        # device_map="auto" picks; confirm this is intended on CPU-only hosts.
        model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
    except Exception as e:
        # Deliberately non-fatal, but keep the traceback: the UI tells users
        # to "check the logs", and a one-line message alone hides the cause.
        import traceback

        print(f"Error loading model: {e}")
        traceback.print_exc()
        return None, None, False

    return tokenizer, model, True

def summarize_dialogue(dialogue):
    """Generate summary for the dialogue"""
    if not dialogue.strip():
        return "Please enter a dialogue to summarize."
    
    if not model_loaded:
        return "Error: Model not loaded. Please check the logs."
    
    try:
        # Create prompt
        prompt = f"Summarize the following conversation.\n\n{dialogue}\n\nSummary: "
        
        # Tokenize input
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(model.device)
        
        # Generate summary
        model.eval()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                num_beams=4,
                length_penalty=0.6,
                early_stopping=True,
                do_sample=False
            )
        
        # Decode and return
        summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return summary
        
    except Exception as e:
        return f"Error generating summary: {str(e)}"

# Load model
# Runs once at import time. On failure load_model() returns
# (None, None, False), the app still starts, and summarize_dialogue()
# answers every request with an error message because model_loaded is False.
tokenizer, model, model_loaded = load_model()

# Create Gradio interface
# The layout follows the order in which components are created inside the
# nested Blocks / Row / Column context managers: the input column (textbox
# and buttons) is created first, then the output column with the summary.
with gr.Blocks(title="Dialogue Summarization") as demo:
    gr.Markdown("# 🤖 Dialogue Summarization")
    gr.Markdown("Enter a conversation below to generate an AI summary using FLAN-T5.")
    
    with gr.Row():
        # Input column: dialogue textbox plus the action buttons.
        with gr.Column():
            dialogue_input = gr.Textbox(
                label="Dialogue",
                placeholder="Enter your conversation here...",
                lines=8
            )
            
            submit_btn = gr.Button("Generate Summary", variant="primary")
            clear_btn = gr.Button("Clear")
        
        # Output column: generated summary.
        with gr.Column():
            summary_output = gr.Textbox(
                label="Summary",
                lines=6,
                show_copy_button=True
            )
    
    # Examples
    # Clicking an example fills dialogue_input with one of the `examples`
    # strings; a flat list is passed because there is a single input component.
    gr.Examples(
        examples=examples,
        inputs=dialogue_input,
        label="Try these examples:"
    )
    
    # Event handlers
    # "Generate Summary": feed the textbox contents to summarize_dialogue and
    # show its return value (summary or error message) in the output box.
    submit_btn.click(
        fn=summarize_dialogue,
        inputs=dialogue_input,
        outputs=summary_output
    )
    
    # "Clear": reset both the input and the output textboxes to empty strings.
    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[dialogue_input, summary_output]
    )

# Launch
# NOTE(review): share=True asks Gradio for a public share link when this file
# is run directly; confirm that exposing the app this way is intended.
if __name__ == "__main__":
    demo.launch(share=True)