navidfalah committed on
Commit
b35a0e0
Β·
verified Β·
1 Parent(s): 3661def

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -53
app.py CHANGED
@@ -12,81 +12,206 @@ if token:
12
  login(token=token)
13
  print("Successfully logged in to Hugging Face!")
14
 
15
- # Load the base model and tokenizer with better error handling
16
- model_name = "mistralai/Mistral-7B-Instruct-v0.1"
17
  print("Loading tokenizer...")
18
 
19
  try:
20
- # Try loading with use_fast=False to avoid tokenizer issues
21
- tokenizer = AutoTokenizer.from_pretrained(
22
- model_name,
23
- use_fast=False,
24
- trust_remote_code=True
 
25
  )
 
 
26
  except Exception as e:
27
- print(f"Error loading tokenizer: {e}")
28
- print("Trying alternative tokenizer...")
29
- # Fallback to a known working tokenizer
30
- tokenizer = AutoTokenizer.from_pretrained(
31
- "huggingface/CodeBERTa-small-v1",
32
- use_fast=False
33
- )
34
 
35
  # Set pad token if not set
36
  if tokenizer.pad_token is None:
37
  tokenizer.pad_token = tokenizer.eos_token
38
 
39
- print("Loading model...")
40
- # Load base model with better memory management
41
- model = AutoModelForCausalLM.from_pretrained(
42
- model_name,
43
- torch_dtype=torch.float16,
44
- device_map="auto",
45
- offload_folder="./offload" # Add offload directory
46
- )
47
-
48
- print("Model loaded successfully!")
49
-
50
- def chat_function(message, history):
51
  try:
52
- # Simple format for Mistral
53
- prompt = f"<s>[INST] {message} [/INST]"
 
54
 
55
- # Tokenize
56
- inputs = tokenizer(
57
- prompt,
58
- return_tensors="pt",
59
- truncation=True,
60
- max_length=1024
61
- )
62
 
63
- # Generate
64
  with torch.no_grad():
65
- outputs = model.generate(
66
- inputs.input_ids,
67
- max_new_tokens=256,
68
- temperature=0.7,
69
  do_sample=True,
70
- pad_token_id=tokenizer.eos_token_id
 
 
 
71
  )
72
 
73
  # Decode response
74
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
75
- response = response.replace(prompt, "").strip()
76
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  return response
78
 
79
  except Exception as e:
80
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- # Create the simplest possible interface
83
- demo = gr.Interface(
84
- fn=chat_function,
85
- inputs=gr.Textbox(placeholder="Type your message here..."),
86
- outputs=gr.Textbox(),
87
- title="3AI Custom Mistral Model",
88
- description="Chat with the Mistral 7B model"
89
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  if __name__ == "__main__":
92
- demo.launch(share=True)
 
12
  login(token=token)
13
  print("Successfully logged in to Hugging Face!")
14
 
15
# Pick a lightweight conversational model that runs reliably in Spaces.
model_name = "microsoft/DialoGPT-medium"
print("Loading tokenizer...")

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    print("Model loaded successfully!")

except Exception as e:
    # Any failure (download, memory, gated access) falls back to plain GPT-2
    # so the Space still comes up instead of crashing at import time.
    print(f"Error loading model: {e}")
    print("Using a simpler fallback...")
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    print("Fallback model loaded!")

# DialoGPT/GPT-2 tokenizers ship without a pad token; reuse EOS so that
# generate() has a valid pad_token_id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
41
 
42
def chat_function(message, temperature, max_length):
    """Generate a single chat reply from the loaded causal-LM.

    Parameters:
        message: the user's text (str) from the Gradio textbox.
        temperature: sampling temperature from the UI slider (0.1-1.0).
        max_length: maximum NEW tokens to generate (20-100 from the slider).

    Returns the decoded response string. Errors are returned as a friendly
    message rather than raised, so the Gradio UI never shows a traceback.
    """
    import random  # hoisted from mid-function; used for canned fallbacks below

    if not message or not message.strip():
        return "Please enter a message to get started! 💬"

    # Limit input length
    if len(message) > 500:
        return "⚠️ Message too long! Please keep it under 500 characters."

    try:
        # Encode the input with a hard character limit.
        message = message.strip()[:200]
        input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')

        # Ensure the encoded prompt isn't too long.
        if input_ids.shape[-1] > 100:
            input_ids = input_ids[:, :100]

        # BUG FIX: with device_map="auto" the model may be placed on GPU while
        # the freshly encoded ids are on CPU; move them to the model's device
        # so generate() doesn't fail with a device mismatch on GPU Spaces.
        input_ids = input_ids.to(model.device)

        # Generate response with controlled parameters.
        with torch.no_grad():
            chat_history_ids = model.generate(
                input_ids,
                max_length=min(input_ids.shape[-1] + max_length, 200),  # total cap: 200 tokens
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
                repetition_penalty=1.1,
                top_p=0.9
            )

        # Decode only the newly generated tail (everything after the prompt).
        response = tokenizer.decode(
            chat_history_ids[:, input_ids.shape[-1]:][0],
            skip_special_tokens=True
        ).strip()

        if not response:
            # Model produced nothing usable; answer with a canned prompt.
            responses = [
                "I'm here to help! What would you like to know? 🤔",
                "How can I assist you today? 😊",
                "What's on your mind? 💭",
                "I'm ready to chat! What would you like to discuss? 🗣️"
            ]
            response = random.choice(responses)

        return response

    except Exception as e:
        return f"🚫 Sorry, I encountered an error: {str(e)}"
91
+
92
def clear_chat():
    """Reset both the message box and the response box to empty strings."""
    return ("", "")
94
+
95
# Custom CSS for better styling
css = """
.gradio-container {
    max-width: 800px !important;
    margin: auto !important;
}
.chat-container {
    border-radius: 15px;
    box-shadow: 0 4px 15px rgba(0,0,0,0.1);
}
.input-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 20px;
    border-radius: 15px;
    margin-bottom: 10px;
}
.output-container {
    background: #f8f9fa;
    border-radius: 15px;
    padding: 20px;
}
"""

# Create a refined interface
with gr.Blocks(title="3AI Chat Bot", css=css, theme=gr.themes.Soft()) as demo:
    # Header banner.
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
        <h1 style="color: #2c3e50; margin-bottom: 10px;">🤖 3AI Chat Bot</h1>
        <p style="color: #7f8c8d; font-size: 16px;">Powered by AI • Fast • Intelligent • Helpful</p>
    </div>
    """)

    with gr.Row():
        # Left column: chat input plus the send/clear buttons.
        with gr.Column(scale=2):
            with gr.Group(elem_classes="input-container"):
                gr.Markdown("### 💬 **Chat with AI**")
                message_input = gr.Textbox(
                    placeholder="Type your message here... (max 500 characters)",
                    label="Your Message",
                    lines=3,
                    max_lines=5,
                    show_label=False,
                    container=False,
                )

            with gr.Row():
                submit_btn = gr.Button("🚀 Send", variant="primary", scale=2)
                clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)

        # Right column: generation settings.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### ⚙️ **Settings**")
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="🌡️ Creativity",
                    info="Higher = more creative",
                )
                max_length = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=50,
                    step=10,
                    label="📏 Response Length",
                    info="Max tokens to generate",
                )

    # NOTE(review): original indentation was lost in the diff; the output
    # group is assumed to sit at top Blocks level (full width, below the
    # row) — verify against the deployed layout.
    with gr.Group(elem_classes="output-container"):
        gr.Markdown("### 🤖 **AI Response**")
        response_output = gr.Textbox(
            label="",
            lines=6,
            max_lines=10,
            interactive=False,
            show_label=False,
            container=False,
            placeholder="AI responses will appear here... ✨",
        )

    # Character counter
    char_count = gr.HTML("<p style='text-align: right; color: #7f8c8d; font-size: 12px;'>0/500 characters</p>")

    # Event handlers: both the button and Enter-in-textbox trigger a reply.
    submit_btn.click(
        fn=chat_function,
        inputs=[message_input, temperature, max_length],
        outputs=response_output,
    )

    message_input.submit(
        fn=chat_function,
        inputs=[message_input, temperature, max_length],
        outputs=response_output,
    )

    clear_btn.click(
        fn=clear_chat,
        outputs=[message_input, response_output],
    )

    # Update character counter
    def update_char_count(text):
        """Return the counter HTML; colour flips to red past the 500-char cap."""
        count = 0 if not text else len(text)
        color = "#e74c3c" if count > 500 else "#7f8c8d"
        return f"<p style='text-align: right; color: {color}; font-size: 12px;'>{count}/500 characters</p>"

    message_input.change(
        fn=update_char_count,
        inputs=message_input,
        outputs=char_count,
    )

    # Add footer
    gr.HTML("""
    <div style="text-align: center; padding: 20px; color: #7f8c8d; font-size: 12px;">
        <p>Made with ❤️ using Gradio • Powered by Transformers</p>
    </div>
    """)
215
 
216
if __name__ == "__main__":
    # Spaces exposes the app itself; no share=True needed here.
    demo.launch()