File size: 4,981 Bytes
85d3665
 
61b7033
ec138fe
 
85d3665
 
 
 
 
 
 
 
 
ec138fe
56df3cf
85d3665
ec138fe
85d3665
 
 
ec138fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85d3665
56df3cf
ec138fe
 
85d3665
ec138fe
85d3665
 
ec138fe
85d3665
ec138fe
 
85d3665
 
ec138fe
 
85d3665
ec138fe
85d3665
 
ec138fe
85d3665
 
 
 
ec138fe
 
85d3665
 
 
ec138fe
 
 
 
 
 
 
 
54df86c
85d3665
ec138fe
85d3665
 
 
 
 
 
 
 
 
 
 
ec138fe
85d3665
54df86c
85d3665
 
 
11e0ac2
85d3665
 
11e0ac2
85d3665
11e0ac2
 
 
54df86c
85d3665
11e0ac2
85d3665
ec138fe
 
11e0ac2
85d3665
 
 
 
ec138fe
 
85d3665
 
ec138fe
11e0ac2
ec138fe
 
85d3665
 
ec138fe
 
 
 
 
 
 
85d3665
11e0ac2
 
 
 
 
85d3665
 
569695c
ec138fe
569695c
85d3665
11e0ac2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import re
import random
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

class ChatAssistant:
    """Chat assistant backed by a Hugging Face causal language model.

    The model is loaded eagerly when the instance is created. If loading
    fails, the instance still works: self-introductions are handled, and
    every other message gets a fixed "technical difficulties" reply.
    """

    # Ordered (pattern, ambiguous) pairs tried by get_user_name. Names are
    # letters only (no digits/underscores), so "I am 25" is not a name.
    _NAME_PATTERNS = (
        (re.compile(r"\bmy name is ([^\W\d_]+)", re.IGNORECASE), False),
        # "I'm X" / "I am X" is usually not an introduction ("I'm trying to ..."),
        # so only accept a word that ends the message or is followed by punctuation.
        (
            re.compile(r"\bi(?:['’]m| am) ([^\W\d_]+)(?=\s*(?:[.!?,;]|$))", re.IGNORECASE),
            True,
        ),
        (re.compile(r"\bcall me ([^\W\d_]+)", re.IGNORECASE), False),
    )
    # Words that commonly follow "I'm"/"I am" in short replies but are not names.
    _NOT_NAMES = frozenset({
        "fine", "good", "great", "ok", "okay", "alright", "well", "here", "back",
        "sure", "sorry", "ready", "done", "lost", "confused", "stuck", "tired",
        "bored", "busy", "happy", "glad", "sad", "new", "curious", "interested",
        "excited", "not", "too", "home", "in", "out",
    })

    def __init__(self):
        self.name = "AI Assistant"
        self.user_name = ""
        self.model_loaded = False
        self.generator = None
        self.tokenizer = None
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the DialoGPT model, tokenizer and text-generation pipeline.

        Returns:
            True if the model is (already) loaded, False if loading failed.
            On failure the error is printed and ``model_loaded`` stays False.
        """
        if self.model_loaded:
            return True

        try:
            use_cuda = torch.cuda.is_available()
            model_name = "microsoft/DialoGPT-medium"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                # On GPU, accelerate places the weights via device_map.
                device_map="auto" if use_cuda else None,
                low_cpu_mem_usage=True,
            )

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            pipeline_kwargs = {}
            if not use_cuda:
                # Only pass `device` when the model was NOT dispatched with
                # device_map: the pipeline raises if asked to move an
                # accelerate-placed model, which broke loading on every GPU.
                pipeline_kwargs["device"] = -1
            self.generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                return_full_text=False,
                **pipeline_kwargs,
            )

            self.model_loaded = True
            return True

        except Exception as e:
            print(f"Could not load AI model: {str(e)}")
            return False

    @staticmethod
    def _trim_reply(text: str) -> str:
        """Strip *text* and drop any hallucinated follow-up turn starting with "User:"."""
        return re.split(r"^\s*User:", text, maxsplit=1, flags=re.MULTILINE)[0].strip()

    def generate_response(self, query: str) -> str:
        """Generate a reply to *query* with the text-generation pipeline.

        Returns a fixed apology when the model is not loaded, and a fixed
        fallback when generation fails or produces only whitespace.
        """
        if not self.model_loaded:
            return "Sorry, I'm having technical difficulties. Please try again later."

        try:
            prompt = f"""The following is a conversation with an AI assistant. The assistant is helpful, knowledgeable, and provides detailed answers.

User: {query}
AI:"""

            response = self.generator(
                prompt,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3,
            )

            if response:
                reply = self._trim_reply(response[0]["generated_text"])
                # An empty reply would render as a blank chat bubble; use the fallback instead.
                if reply:
                    return reply

        except Exception as e:
            # Generation problems are reported on stdout and mapped to the fallback reply.
            print(f"AI generation error: {e}")

        return "I couldn't generate a response. Please try asking differently."

    def process_message(self, message: str) -> str:
        """Answer *message*: greet self-introductions, otherwise ask the model."""
        name_response = self.get_user_name(message)
        if name_response:
            return name_response
        return self.generate_response(message)

    def get_user_name(self, message: str):
        """Detect a self-introduction in *message* and remember the user's name.

        "my name is X" and "call me X" are accepted anywhere. "I'm X" / "I am X"
        is accepted only when X ends the message or is followed by punctuation,
        and X is not a common non-name word such as "fine".

        Returns:
            A greeting using the detected name (also stored, capitalized, in
            ``self.user_name``), or None when no introduction is found.
        """
        for pattern, ambiguous in self._NAME_PATTERNS:
            for match in pattern.finditer(message):
                candidate = match.group(1)
                if ambiguous and candidate.lower() in self._NOT_NAMES:
                    continue
                self.user_name = candidate.capitalize()
                return f"Nice to meet you, {self.user_name}! How can I help you today?"
        return None

# Initialize the assistant.
# One module-level instance is shared by the Gradio callbacks (chat_response uses it).
# Its constructor calls load_model(), so importing this module downloads/loads the model.
assistant = ChatAssistant()

def chat_response(message, chat_history):
    """Handle one chat turn coming from the Gradio UI.

    Whitespace-only messages are ignored. Otherwise the assistant's reply is
    appended to ``chat_history`` in place as a ``(message, reply)`` tuple.

    Returns:
        The (possibly extended) history and an empty string that clears the textbox.
    """
    if message.strip():
        reply = assistant.process_message(message)
        chat_history.append((message, reply))
    return chat_history, ""

def greet():
    """Return a fresh chat history containing one random opening line from the bot.

    The history is a single-item list of ``(user, bot)`` pairs whose user slot
    is ``None``, so the chatbot shows only the assistant's greeting.
    """
    openers = [
        "Hello! I'm your AI assistant. How can I help you today?",
        "Hi there! What would you like to know?",
        "Welcome! I'm ready to answer your questions.",
    ]
    opener = random.choice(openers)
    return [(None, opener)]

# Create Gradio interface
def create_interface():
    """Assemble the Gradio Blocks chat UI and return it without launching it.

    Layout: a chatbot pre-filled with a greeting, a row with the message
    textbox and a Send button, and a Clear Chat button. Submitting the textbox
    and clicking Send both call ``chat_response``; Clear Chat restores a new
    greeting and empties the textbox.
    """
    blocks_options = dict(title="AI Assistant", theme=gr.themes.Soft())
    with gr.Blocks(**blocks_options) as iface:
        chatbot = gr.Chatbot(value=greet(), height=500, label="AI Assistant")

        with gr.Row():
            msg = gr.Textbox(
                label="Type your message",
                placeholder="Ask me anything...",
                lines=2,
            )
            submit_btn = gr.Button("Send", variant="primary")

        clear_btn = gr.Button("Clear Chat")

        # Enter in the textbox and the Send button share one handler and wiring;
        # registration order (submit first, then click) matches the original.
        for trigger in (msg.submit, submit_btn.click):
            trigger(chat_response, [msg, chatbot], [chatbot, msg])
        clear_btn.click(lambda: (greet(), ""), outputs=[chatbot, msg])

    return iface

# Launch the interface when run as a script (not on import).
if __name__ == "__main__":
    # share=True additionally asks Gradio for a public share link.
    interface = create_interface()
    interface.launch(share=True)