File size: 2,132 Bytes
043d0b7
21f0fe1
 
043d0b7
653d26b
21f0fe1
 
 
 
043d0b7
653d26b
 
fb56cfc
 
c4e54cc
043d0b7
21f0fe1
 
653d26b
9bec465
96a9ff4
 
653d26b
96a9ff4
 
21f0fe1
96a9ff4
 
 
 
9bec465
 
 
 
 
 
653d26b
21f0fe1
 
9bec465
 
 
 
 
653d26b
9bec465
043d0b7
653d26b
21f0fe1
 
9bec465
21f0fe1
 
0c3f035
21f0fe1
 
 
9bec465
0c3f035
 
653d26b
21f0fe1
 
fb56cfc
653d26b
21f0fe1
 
fb56cfc
21f0fe1
0c3f035
653d26b
 
fb56cfc
 
0efc926
653d26b
043d0b7
 
653d26b
9bec465
 
 
96a9ff4
9bec465
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and causal-LM weights once, at import time, so that
# every call to reply() reuses the same module-level objects.
model_name = "microsoft/DialoGPT-medium"  # checkpoint id handed to from_pretrained

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


# Chat function
def reply(message, history):
    """Generate the bot's next turn for the Gradio chat callback.

    Every earlier turn plus the new message is tokenized with a trailing
    end-of-sequence token, concatenated, and used as the generation prompt.

    Args:
        message: Text the user just submitted.
        history: Earlier turns as passed in by ``gr.ChatInterface``. Entries
            are normally dicts carrying a ``"content"`` value (a string or a
            list of parts); ``[user, bot]`` pairs are accepted as well.

    Returns:
        The generated reply, a request to type something when ``message`` is
        blank, or a canned fallback when the model yields no visible text.
    """
    if not message or not message.strip():
        return "Please enter a message."

    # max_length in generate() caps prompt + reply together, so part of that
    # budget has to be held back for the reply or long chats get no answer.
    max_total_tokens = 1000
    reply_budget = 128

    eos = tokenizer.eos_token
    turn_ids = []
    for entry in history or []:
        if isinstance(entry, dict):
            contents = [entry.get("content")]
        elif isinstance(entry, (list, tuple)):
            # Pair-style history: one entry holds both sides of an exchange.
            contents = entry
        else:
            continue

        for content in contents:
            text = _content_to_text(content)
            if text:
                turn_ids.append(
                    tokenizer.encode(text + eos, return_tensors="pt")
                )

    # Current message always goes last.
    turn_ids.append(tokenizer.encode(message + eos, return_tensors="pt"))

    # One concatenation instead of re-copying the running tensor per turn.
    input_ids = torch.cat(turn_ids, dim=-1)

    # Drop the oldest tokens first; the most recent context matters most.
    input_ids = input_ids[:, -(max_total_tokens - reply_budget):]

    # No padding is present, so every position is attended to.
    attention_mask = torch.ones_like(input_ids)

    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_total_tokens,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        repetition_penalty=1.2,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    response_ids = output_ids[:, input_ids.shape[-1]:]
    response = tokenizer.decode(response_ids[0], skip_special_tokens=True)

    # The model may emit EOS immediately, which decodes to an empty string.
    if not response.strip():
        response = "I'm here! How can I help you?"

    return response


def _content_to_text(content):
    """Flatten one history ``content`` value to plain text ("" if it has none)."""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        # Structured part, e.g. {"type": "text", "text": "..."}; parts without
        # a string "text" field (files, media) carry nothing to tokenize.
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, (list, tuple)):
        pieces = (_content_to_text(part) for part in content)
        return " ".join(piece for piece in pieces if piece)
    return ""


# Chat UI that routes each user message through reply().
demo = gr.ChatInterface(
    reply,
    title="💬 Smart Dialogue System",
    description="Full conversation chatbot using DialoGPT",
)

# Serve on all interfaces, port 7860, with ssr_mode switched off.
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)