navidfalah commited on
Commit
a2b8ec8
·
verified ·
1 Parent(s): dab87da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -210
app.py CHANGED
@@ -1,252 +1,292 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
- from peft import PeftModel
5
  import os
6
- from typing import Tuple, Optional
 
7
 
8
# Configuration
class Config:
    """Central configuration for model loading and generation parameters."""
    MODEL_PATH = "navidfalah/3ai"  # Your HF model repo (LoRA adapter)
    BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"  # Mistral base model
    ADAPTER_PATH = "./model"  # Local adapter path if needed (fallback)
    MAX_NEW_TOKENS = 2000  # upper bound on generated tokens per request
    TEMPERATURE = 0.7  # sampling temperature
    TOP_P = 0.9  # nucleus-sampling cutoff
    MAX_INPUT_LENGTH = 1024  # tokenized input is truncated to this length

# Global variables for model and tokenizer; populated lazily by load_model()
model = None
tokenizer = None
21
 
22
def load_model() -> Tuple[Optional[object], Optional[object]]:
    """Load (and cache) the fine-tuned satisfaction-analysis model.

    Loads the Mistral base model in 4-bit, then attaches the LoRA adapter
    from the HF repo, falling back to a local adapter directory.

    Returns:
        (model, tokenizer) on success, (None, None) on any failure.
    """
    global model, tokenizer

    # Return the cached pair if a previous call already loaded them.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    try:
        print("🔄 Loading Mistral model and tokenizer...")

        # Load tokenizer from base model (Mistral); ensure a pad token exists
        # because generation with padding requires one.
        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # 4-bit quantization config for memory-efficient inference
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Load base Mistral model
        base_model = AutoModelForCausalLM.from_pretrained(
            Config.BASE_MODEL,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )

        # Try loading the adapter from the HF repo first.
        try:
            model = PeftModel.from_pretrained(
                base_model,
                Config.MODEL_PATH,
                is_trainable=False,
            )
            print(" Loaded model from Hugging Face repo")
        # FIX: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit. Narrowed to Exception.
        except Exception:
            # Fallback to local adapter if available
            if os.path.exists(Config.ADAPTER_PATH):
                model = PeftModel.from_pretrained(
                    base_model,
                    Config.ADAPTER_PATH,
                    is_trainable=False,
                )
                print(" Loaded model from local adapter")
            else:
                raise Exception("Could not load adapter from HF or local path")

        model.eval()
        print(" Mistral-7B model loaded successfully!")
        return model, tokenizer

    except Exception as e:
        # Best-effort API: callers check for (None, None) instead of catching.
        print(f" Error loading model: {e}")
        return None, None
 
 
 
 
 
82
 
83
def analyze_satisfaction(user_input: str) -> str:
    """Generate satisfaction analysis based on user input text.

    Args:
        user_input: Free-form description of the user's life situation.

    Returns:
        Markdown-formatted analysis, or a human-readable error/warning string.
    """
    # Load model if not already loaded (load_model caches after first call)
    model, tokenizer = load_model()

    if model is None or tokenizer is None:
        return "Error: Could not load the model. Please try again later."

    if not user_input.strip():
        return "⚠️ Please enter some text describing your life situation or what you'd like analyzed."

    try:
        # Tokenize input
        inputs = tokenizer(
            user_input,
            return_tensors="pt",
            truncation=True,
            max_length=Config.MAX_INPUT_LENGTH,
        )

        # FIX: always move inputs to the model's device. The original only
        # moved them when torch.cuda.is_available(), but device placement is
        # decided by device_map="auto", not by this process's CUDA check.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=Config.MAX_NEW_TOKENS,
                temperature=Config.TEMPERATURE,
                top_p=Config.TOP_P,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # FIX: decode only the newly generated tokens. Slicing the decoded
        # string by len(user_input) is fragile because the tokenize/decode
        # round-trip can change whitespace and special characters.
        prompt_len = inputs["input_ids"].shape[1]
        generated_text = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

        # Format the output
        formatted_output = "## 📊 Life Satisfaction Analysis\n\n"
        formatted_output += generated_text

        return formatted_output

    except Exception as e:
        # Surface the failure to the UI rather than crashing the app.
        return f"Error during analysis: {str(e)}"
134
 
135
# Example prompts for users — shown in the UI dropdown; selecting one
# copies it into the input textbox (see create_interface).
EXAMPLE_PROMPTS = [
    "I'm a 29-year-old professional feeling burned out at work. My health is okay but I rarely exercise. Financially stable but not saving much. Great relationship with my partner. What's my life satisfaction score?",
    "Rate my life satisfaction: Work is stressful (3/10), health is good (7/10), finances are tight (4/10), relationships are excellent (9/10). Give me a comprehensive analysis.",
    "Analyze my satisfaction: Career going well, making good money, but no time for friends or hobbies. Always tired and stressed. How can I improve?",
    "I'm happy with my job and relationships but struggling with debt and health issues. Need advice on balancing everything.",
    "Just graduated, starting my career, living paycheck to paycheck, single but happy, very healthy and active. Analyze my life satisfaction."
]
143
 
144
# Gradio Interface
def create_interface():
    """Create the Gradio interface.

    Builds a Blocks layout: input textbox + analyze/clear buttons, an
    example-prompt dropdown, a Markdown output area, a collapsible tips
    accordion, and a footer. Returns the (unlaunched) gr.Blocks object.
    """
    with gr.Blocks(title="Life Satisfaction Analysis", theme=gr.themes.Soft()) as demo:
        # Page header / feature overview
        gr.Markdown(
            """
            # 🌟 AI Life Satisfaction Analyzer

            This AI-powered tool analyzes your life satisfaction based on your description of your current situation.
            Simply describe your life circumstances, challenges, and satisfaction levels across different areas.

            **The AI will analyze:**
            - Overall life satisfaction score
            - Balance across life domains (work, health, finances, relationships)
            - Personalized recommendations for improvement
            - Action plans and strategies
            """
        )

        with gr.Row():
            with gr.Column():
                # Input section
                input_text = gr.Textbox(
                    label="📝 Describe Your Current Life Situation",
                    placeholder="Tell me about your work, health, finances, relationships, and any other aspects of your life you'd like analyzed. You can include satisfaction ratings (1-10) or just describe how you feel about each area.",
                    lines=8,
                    max_lines=15
                )

                with gr.Row():
                    analyze_btn = gr.Button("🔍 Analyze My Life Satisfaction", variant="primary", scale=2)
                    clear_btn = gr.Button("🗑️ Clear", scale=1)

                # Examples section
                gr.Markdown("### 💡 Example Inputs")
                example_dropdown = gr.Dropdown(
                    choices=EXAMPLE_PROMPTS,
                    label="Select an example to try:",
                    interactive=True
                )

        with gr.Row():
            with gr.Column():
                # Output section (rendered Markdown, filled by the handler)
                output = gr.Markdown(label="Analysis Results")

        # Event handlers
        analyze_btn.click(
            fn=analyze_satisfaction,
            inputs=input_text,
            outputs=output
        )

        # Clearing resets both the input textbox and the rendered output.
        clear_btn.click(
            fn=lambda: ("", ""),
            inputs=[],
            outputs=[input_text, output]
        )

        # Selecting an example copies it verbatim into the input textbox.
        example_dropdown.change(
            fn=lambda x: x,
            inputs=example_dropdown,
            outputs=input_text
        )

        # Tips section (collapsed by default)
        with gr.Accordion("📖 Tips for Best Results", open=False):
            gr.Markdown(
                """
                **How to get the most accurate analysis:**

                1. **Be specific** about your situation in each life area
                2. **Include ratings** (1-10) if you want quantified analysis
                3. **Mention your age** and life stage for context
                4. **Describe challenges** you're facing
                5. **Share your goals** or what you'd like to improve

                **Example format:**
                - Work: [Your situation and satisfaction level]
                - Health: [Physical and mental wellness status]
                - Finances: [Financial situation and concerns]
                - Relationships: [Social and romantic relationships]
                - Personal: [Hobbies, growth, fulfillment]
                """
            )

        # Footer
        gr.Markdown(
            """
            ---
            💡 **Disclaimer:** This AI tool provides general insights based on the information you provide.
            For professional advice, please consult qualified experts in relevant fields.

            🔒 **Privacy:** Your input is processed in real-time and not stored.
            """
        )

    return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
# Launch the app
if __name__ == "__main__":
    # Load model on startup so the first user request does not pay the
    # (long) model-download/load cost.
    print("🚀 Starting Life Satisfaction Analysis Tool...")
    load_model()

    # Create and launch interface (blocks until the server stops)
    demo = create_interface()
    demo.launch()
 
1
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import os
import subprocess
import sys

print("Starting 3AI application...")

# Install required dependencies at runtime.
# NOTE(review): pip-installing from application code is fragile; this
# belongs in requirements.txt — kept because startup currently relies on it.
print("Installing required dependencies...")
try:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "sentencepiece", "protobuf", "peft", "--quiet"])
    print("Dependencies installed successfully!")
except Exception as e:
    print(f"Warning: Could not install dependencies: {e}")

# Import PEFT after installation; retry the install once on failure.
try:
    from peft import PeftModel, PeftConfig
    print("PEFT imported successfully!")
except ImportError as e:
    print(f"Could not import PEFT: {e}")
    print("Trying to install PEFT again...")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "peft", "--force-reinstall"])
        from peft import PeftModel, PeftConfig
        print("PEFT installed and imported successfully!")
    except Exception as e2:
        print(f"Failed to install PEFT: {e2}")
        print("Continuing without PEFT - will try alternative approach")
        # Sentinels: checked below before attempting adapter loading.
        PeftModel = None
        PeftConfig = None

# Login using the secret token (presumably a Space secret — verify deployment)
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
    print("Successfully logged in to Hugging Face!")

# Use your own Hugging Face model
original_mistral_model = "navidfalah/3ai"  # Your model on Hugging Face
adapter_path = "./model"  # Your local LoRA adapter directory (if available)

# Tokenizer loading cascade: own repo -> Mistral v0.1 -> Mistral v0.2 ->
# DialoGPT -> GPT-2. Each step only runs if the previous one raised.
print(f"Loading original Mistral tokenizer from {original_mistral_model}...")
try:
    # First try: Load with slow tokenizer from your model
    tokenizer = AutoTokenizer.from_pretrained(
        original_mistral_model,
        use_fast=False,  # Use slow tokenizer to avoid issues
        force_download=True,  # Force fresh download
        resume_download=False
    )
    print("Your model tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading tokenizer from your model: {e}")
    try:
        # Second try: Use original Mistral tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            use_fast=False
        )
        print("Original Mistral tokenizer loaded successfully!")
    except Exception as e2:
        print(f"Error with original Mistral: {e2}")
        try:
            # Third try: Use different Mistral model version
            print("Trying Mistral-7B-Instruct-v0.2...")
            tokenizer = AutoTokenizer.from_pretrained(
                "mistralai/Mistral-7B-Instruct-v0.2",
                use_fast=False
            )
            print("Mistral v0.2 tokenizer loaded successfully!")
        except Exception as e3:
            print(f"Error with Mistral v0.2: {e3}")
            try:
                # Fourth try: Use compatible tokenizer
                # NOTE(review): DialoGPT/GPT-2 vocabularies do not match a
                # Mistral checkpoint — these fallbacks keep the app running
                # but output quality will be degraded; confirm intended.
                print("Trying compatible tokenizer...")
                tokenizer = AutoTokenizer.from_pretrained(
                    "microsoft/DialoGPT-medium",
                    use_fast=False
                )
                print("Compatible tokenizer loaded successfully!")
            except Exception as e4:
                print(f"Error with compatible tokenizer: {e4}")
                try:
                    # Fifth try: Use GPT-2 as fallback
                    print("Using GPT-2 as fallback...")
                    tokenizer = AutoTokenizer.from_pretrained("gpt2")
                    print("GPT-2 tokenizer loaded successfully!")
                except Exception as e5:
                    print(f"Cannot load any tokenizer: {e5}")
                    print("Exiting - cannot proceed without tokenizer")
                    exit(1)

# Ensure tokenizer has proper tokens (generation with padding needs one)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Loading your model from {original_mistral_model}...")
try:
    # Load your model from Hugging Face
    base_model = AutoModelForCausalLM.from_pretrained(
        original_mistral_model,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True
    )
    print("Your model loaded successfully!")

    # Check if PEFT is available and try to load local adapter on top.
    if PeftModel is not None and PeftConfig is not None:
        try:
            print(f"Trying to load local LoRA adapter from {adapter_path}...")
            model = PeftModel.from_pretrained(
                base_model,
                adapter_path,
                torch_dtype=torch.float16
            )
            print("Local LoRA adapter loaded successfully!")
        except Exception as adapter_error:
            print(f"Could not load local adapter: {adapter_error}")
            print("Using your base model without additional adapter")
            model = base_model
    else:
        print("PEFT not available - using your base model")
        model = base_model

except Exception as e:
    print(f"Error loading your model: {e}")
    print("Trying to load original Mistral as fallback...")
    try:
        # Fallback to original Mistral
        base_model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
        print("Fallback Mistral model loaded!")
        model = base_model
    except Exception as e2:
        print(f"Cannot load any model: {e2}")
        print("Exiting - cannot proceed without model")
        exit(1)
148
def chat_function(message):
    """Generate a reply to *message* with the module-level model/tokenizer.

    Args:
        message: User input text (at most 500 characters after stripping).

    Returns:
        The generated text, or a human-readable error/warning string.
    """
    if not message or not message.strip():
        return "Please enter a message."

    # Clean and limit input
    message = message.strip()
    if len(message) > 500:
        return "Message too long! Please keep it under 500 characters."

    try:
        # FIX: choose the prompt format by the chat template's VALUE.
        # hasattr(tokenizer, 'chat_template') is True for every modern
        # transformers tokenizer (the attribute exists even when it is
        # None), so the original check always selected the Mistral format
        # — including for the GPT-2/DialoGPT fallback tokenizers.
        if getattr(tokenizer, "chat_template", None) or "mistral" in tokenizer.name_or_path.lower():
            # Use Mistral format if it's actually Mistral
            prompt = f"<s>[INST] {message} [/INST]"
        else:
            # Use simple format for other tokenizers
            prompt = f"User: {message}\nAssistant:"

        # Tokenize input
        inputs = tokenizer(
            prompt,
            return_tensors='pt',
            truncation=True,
            max_length=400,
            padding=True
        )
        input_ids = inputs['input_ids']
        attention_mask = inputs.get('attention_mask', None)

        # Move to model device (device_map="auto" decides placement)
        device = next(model.parameters()).device
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # Generate response
        with torch.no_grad():
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            outputs = model.generate(
                input_ids,
                max_new_tokens=200,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                attention_mask=attention_mask,
                repetition_penalty=1.1
            )

        # Extract and clean response: prefer decoding only the newly
        # generated token slice; fall back to stripping the prompt text.
        if outputs.shape[1] > input_ids.shape[1]:
            response_ids = outputs[0][input_ids.shape[1]:]
            response = tokenizer.decode(response_ids, skip_special_tokens=True)
        else:
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = response.replace(prompt, "")

        # Clean up response (single strip; the original stripped twice)
        response = response.strip()

        # Remove prompt artifacts the model may have echoed
        for artifact in ["[/INST]", "[INST]", "Assistant:", "User:", "Human:"]:
            if artifact in response:
                response = response.split(artifact)[-1].strip()

        # Remove input if it appears in response
        if message.lower() in response.lower():
            response = response.replace(message, "").strip()

        # Ensure reasonable length
        if len(response) > 800:
            response = response[:800] + "..."

        # Fallback if empty
        if len(response.strip()) < 3:
            response = "I understand. How can I help you?"

        return response

    except Exception as e:
        # Surface the failure in the UI instead of crashing the handler.
        return f"Error: {str(e)}"
232
 
233
def clear_chat():
    """Return an empty string, used to reset the output textbox."""
    cleared = ""
    return cleared
 
 
 
 
 
 
235
 
236
# Simple custom CSS: center the app and cap its width
css = """
.gradio-container {
    max-width: 700px !important;
    margin: auto !important;
}
"""

# Create interface (flat script — `demo` is launched in the __main__ guard)
with gr.Blocks(title="3AI - Text Generation", css=css, theme=gr.themes.Default()) as demo:
    # Header
    gr.Markdown("""
    # 🤖 3AI Text Generator
    *Simple text-to-text generation with your navidfalah/3ai model*
    """)

    # Input
    with gr.Row():
        input_text = gr.Textbox(
            placeholder="Enter your text here...",
            label="Input Text",
            lines=2,
            max_lines=3
        )

    # Generate button
    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary", size="lg")

    # Output (read-only; filled by chat_function)
    with gr.Row():
        output_text = gr.Textbox(
            label="Generated Text",
            lines=6,
            max_lines=10,
            interactive=False,
            placeholder="Generated text will appear here..."
        )

    # Event handlers: the button click and pressing Enter in the input box
    # both run the same generation function.
    generate_btn.click(
        fn=chat_function,
        inputs=input_text,
        outputs=output_text
    )

    input_text.submit(
        fn=chat_function,
        inputs=input_text,
        outputs=output_text
    )

    # Footer
    gr.Markdown("---\n*navidfalah/3ai • Simple Text Generation*")
290
 
 
291
if __name__ == "__main__":
    # Launch the Gradio app (blocks until the server is stopped).
    demo.launch()