navidfalah committed on
Commit
dab87da
Β·
verified Β·
1 Parent(s): e12356b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -250
app.py CHANGED
@@ -1,292 +1,252 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- from huggingface_hub import login
5
  import os
6
- import subprocess
7
- import sys
8
 
9
- print("Starting 3AI application...")
10
-
11
- # Install required dependencies
12
- print("Installing required dependencies...")
13
- try:
14
- subprocess.check_call([sys.executable, "-m", "pip", "install", "sentencepiece", "protobuf", "peft", "--quiet"])
15
- print("Dependencies installed successfully!")
16
- except Exception as e:
17
- print(f"Warning: Could not install dependencies: {e}")
18
-
19
- # Import PEFT after installation
20
- try:
21
- from peft import PeftModel, PeftConfig
22
- print("PEFT imported successfully!")
23
- except ImportError as e:
24
- print(f"Could not import PEFT: {e}")
25
- print("Trying to install PEFT again...")
26
- try:
27
- subprocess.check_call([sys.executable, "-m", "pip", "install", "peft", "--force-reinstall"])
28
- from peft import PeftModel, PeftConfig
29
- print("PEFT installed and imported successfully!")
30
- except Exception as e2:
31
- print(f"Failed to install PEFT: {e2}")
32
- print("Continuing without PEFT - will try alternative approach")
33
- PeftModel = None
34
- PeftConfig = None
35
-
36
- # Login using the secret token
37
- token = os.getenv("HF_TOKEN")
38
- if token:
39
- login(token=token)
40
- print("Successfully logged in to Hugging Face!")
41
-
42
- # Use your own Hugging Face model
43
- original_mistral_model = "navidfalah/3ai" # Your model on Hugging Face
44
- adapter_path = "./model" # Your local LoRA adapter directory (if available)
45
 
46
- print(f"Loading original Mistral tokenizer from {original_mistral_model}...")
47
- try:
48
- # First try: Load with slow tokenizer from your model
49
- tokenizer = AutoTokenizer.from_pretrained(
50
- original_mistral_model,
51
- use_fast=False, # Use slow tokenizer to avoid issues
52
- force_download=True, # Force fresh download
53
- resume_download=False
54
- )
55
- print("Your model tokenizer loaded successfully!")
56
- except Exception as e:
57
- print(f"Error loading tokenizer from your model: {e}")
58
  try:
59
- # Second try: Use original Mistral tokenizer
60
- tokenizer = AutoTokenizer.from_pretrained(
61
- "mistralai/Mistral-7B-Instruct-v0.1",
62
- use_fast=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
- print("Original Mistral tokenizer loaded successfully!")
65
- except Exception as e2:
66
- print(f"Error with original Mistral: {e2}")
67
  try:
68
- # Third try: Use different Mistral model version
69
- print("Trying Mistral-7B-Instruct-v0.2...")
70
- tokenizer = AutoTokenizer.from_pretrained(
71
- "mistralai/Mistral-7B-Instruct-v0.2",
72
- use_fast=False
73
  )
74
- print("Mistral v0.2 tokenizer loaded successfully!")
75
- except Exception as e3:
76
- print(f"Error with Mistral v0.2: {e3}")
77
- try:
78
- # Fourth try: Use compatible tokenizer
79
- print("Trying compatible tokenizer...")
80
- tokenizer = AutoTokenizer.from_pretrained(
81
- "microsoft/DialoGPT-medium",
82
- use_fast=False
83
  )
84
- print("Compatible tokenizer loaded successfully!")
85
- except Exception as e4:
86
- print(f"Error with compatible tokenizer: {e4}")
87
- try:
88
- # Fifth try: Use GPT-2 as fallback
89
- print("Using GPT-2 as fallback...")
90
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
91
- print("GPT-2 tokenizer loaded successfully!")
92
- except Exception as e5:
93
- print(f"Cannot load any tokenizer: {e5}")
94
- print("Exiting - cannot proceed without tokenizer")
95
- exit(1)
96
-
97
- # Ensure tokenizer has proper tokens
98
- if tokenizer.pad_token is None:
99
- tokenizer.pad_token = tokenizer.eos_token
100
 
101
- print(f"Loading your model from {original_mistral_model}...")
102
- try:
103
- # Load your model from Hugging Face
104
- base_model = AutoModelForCausalLM.from_pretrained(
105
- original_mistral_model,
106
- torch_dtype=torch.float16,
107
- device_map="auto",
108
- low_cpu_mem_usage=True
109
- )
110
- print("Your model loaded successfully!")
111
 
112
- # Check if PEFT is available and try to load local adapter
113
- if PeftModel is not None and PeftConfig is not None:
114
- try:
115
- print(f"Trying to load local LoRA adapter from {adapter_path}...")
116
- model = PeftModel.from_pretrained(
117
- base_model,
118
- adapter_path,
119
- torch_dtype=torch.float16
120
- )
121
- print("Local LoRA adapter loaded successfully!")
122
- except Exception as adapter_error:
123
- print(f"Could not load local adapter: {adapter_error}")
124
- print("Using your base model without additional adapter")
125
- model = base_model
126
- else:
127
- print("PEFT not available - using your base model")
128
- model = base_model
129
 
130
- except Exception as e:
131
- print(f"Error loading your model: {e}")
132
- print("Trying to load original Mistral as fallback...")
133
- try:
134
- # Fallback to original Mistral
135
- base_model = AutoModelForCausalLM.from_pretrained(
136
- "mistralai/Mistral-7B-Instruct-v0.1",
137
- torch_dtype=torch.float16,
138
- device_map="auto",
139
- low_cpu_mem_usage=True
140
- )
141
- print("Fallback Mistral model loaded!")
142
- model = base_model
143
- except Exception as e2:
144
- print(f"Cannot load any model: {e2}")
145
- print("Exiting - cannot proceed without model")
146
- exit(1)
147
-
148
- def chat_function(message):
149
- if not message or not message.strip():
150
- return "Please enter a message."
151
 
152
- # Clean and limit input
153
- message = message.strip()
154
- if len(message) > 500:
155
- return "Message too long! Please keep it under 500 characters."
156
 
157
  try:
158
- # Use flexible prompt format based on tokenizer type
159
- if hasattr(tokenizer, 'chat_template') or 'mistral' in tokenizer.name_or_path.lower():
160
- # Use Mistral format if it's actually Mistral
161
- prompt = f"<s>[INST] {message} [/INST]"
162
- else:
163
- # Use simple format for other tokenizers
164
- prompt = f"User: {message}\nAssistant:"
165
-
166
  # Tokenize input
167
  inputs = tokenizer(
168
- prompt,
169
- return_tensors='pt',
170
  truncation=True,
171
- max_length=400,
172
- padding=True
173
  )
174
- input_ids = inputs['input_ids']
175
- attention_mask = inputs.get('attention_mask', None)
176
 
177
- # Move to model device
178
- device = next(model.parameters()).device
179
- input_ids = input_ids.to(device)
180
- if attention_mask is not None:
181
- attention_mask = attention_mask.to(device)
182
 
183
  # Generate response
184
  with torch.no_grad():
185
- if torch.cuda.is_available():
186
- torch.cuda.empty_cache()
187
-
188
  outputs = model.generate(
189
- input_ids,
190
- max_new_tokens=200,
191
- temperature=0.7,
 
192
  do_sample=True,
193
- top_p=0.9,
194
- pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
195
- eos_token_id=tokenizer.eos_token_id,
196
- attention_mask=attention_mask,
197
- repetition_penalty=1.1
198
  )
199
 
200
- # Extract and clean response
201
- if outputs.shape[1] > input_ids.shape[1]:
202
- response_ids = outputs[0][input_ids.shape[1]:]
203
- response = tokenizer.decode(response_ids, skip_special_tokens=True)
204
- else:
205
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
206
- response = response.replace(prompt, "").strip()
207
 
208
- # Clean up response
209
- response = response.strip()
210
 
211
- # Remove prompt artifacts
212
- for artifact in ["[/INST]", "[INST]", "Assistant:", "User:", "Human:"]:
213
- if artifact in response:
214
- response = response.split(artifact)[-1].strip()
215
 
216
- # Remove input if it appears in response
217
- if message.lower() in response.lower():
218
- response = response.replace(message, "").strip()
219
-
220
- # Ensure reasonable length
221
- if len(response) > 800:
222
- response = response[:800] + "..."
223
-
224
- # Fallback if empty
225
- if len(response.strip()) < 3:
226
- response = "I understand. How can I help you?"
227
-
228
- return response
229
 
230
  except Exception as e:
231
- return f"Error: {str(e)}"
232
 
233
- def clear_chat():
234
- return ""
 
 
 
 
 
 
235
 
236
- # Simple custom CSS
237
- css = """
238
- .gradio-container {
239
- max-width: 700px !important;
240
- margin: auto !important;
241
- }
242
- """
243
-
244
- # Create interface
245
- with gr.Blocks(title="3AI - Text Generation", css=css, theme=gr.themes.Default()) as demo:
246
- # Header
247
- gr.Markdown("""
248
- # πŸ€– 3AI Text Generator
249
- *Simple text-to-text generation with your navidfalah/3ai model*
250
- """)
251
 
252
- # Input
253
- with gr.Row():
254
- input_text = gr.Textbox(
255
- placeholder="Enter your text here...",
256
- label="Input Text",
257
- lines=2,
258
- max_lines=3
 
 
 
 
 
 
 
259
  )
260
-
261
- # Generate button
262
- with gr.Row():
263
- generate_btn = gr.Button("Generate", variant="primary", size="lg")
264
-
265
- # Output
266
- with gr.Row():
267
- output_text = gr.Textbox(
268
- label="Generated Text",
269
- lines=6,
270
- max_lines=10,
271
- interactive=False,
272
- placeholder="Generated text will appear here..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  )
274
 
275
- # Event handlers
276
- generate_btn.click(
277
- fn=chat_function,
278
- inputs=input_text,
279
- outputs=output_text
280
- )
281
-
282
- input_text.submit(
283
- fn=chat_function,
284
- inputs=input_text,
285
- outputs=output_text
286
- )
287
-
288
- # Footer
289
- gr.Markdown("---\n*navidfalah/3ai β€’ Simple Text Generation*")
290
 
 
291
  if __name__ == "__main__":
 
 
 
 
 
 
292
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
+ from peft import PeftModel
5
  import os
6
+ from typing import Tuple, Optional
 
7
 
8
# Configuration
class Config:
    """Central knobs for model loading and text generation."""
    MODEL_PATH = "navidfalah/3ai"  # Your HF model repo
    BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"  # Mistral base model
    ADAPTER_PATH = "./model"  # Local adapter path if needed
    MAX_NEW_TOKENS = 2000
    TEMPERATURE = 0.7
    TOP_P = 0.9
    MAX_INPUT_LENGTH = 1024


# Lazily-populated singletons shared by load_model() and the analysis function.
model = None
tokenizer = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
def load_model() -> Tuple[Optional[object], Optional[object]]:
    """Load the fine-tuned satisfaction analysis model (cached after first call).

    Loads the Mistral tokenizer and 4-bit-quantized base model, then attaches
    the LoRA adapter from the Hugging Face repo, falling back to a local
    adapter directory.

    Returns:
        (model, tokenizer) on success, (None, None) if loading failed.
    """
    global model, tokenizer

    # Reuse already-loaded instances on subsequent calls.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    try:
        print("🔄 Loading Mistral model and tokenizer...")

        # Load tokenizer from base model (Mistral)
        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # 4-bit NF4 quantization config for memory-efficient inference
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )

        # Load base Mistral model
        base_model = AutoModelForCausalLM.from_pretrained(
            Config.BASE_MODEL,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16
        )

        # Try loading the adapter from the HF repo first.
        try:
            model = PeftModel.from_pretrained(
                base_model,
                Config.MODEL_PATH,
                is_trainable=False
            )
            print("✅ Loaded model from Hugging Face repo")
        # BUG FIX: was a bare `except:`, which also swallows SystemExit and
        # KeyboardInterrupt and hid the actual hub error from the logs.
        except Exception as hub_error:
            print(f"Could not load adapter from HF repo: {hub_error}")
            # Fallback to local adapter if available
            if os.path.exists(Config.ADAPTER_PATH):
                model = PeftModel.from_pretrained(
                    base_model,
                    Config.ADAPTER_PATH,
                    is_trainable=False
                )
                print("✅ Loaded model from local adapter")
            else:
                # Chain the original error so the outer handler's message
                # shows why the hub load failed too.
                raise RuntimeError(
                    "Could not load adapter from HF or local path"
                ) from hub_error

        model.eval()
        print("✅ Mistral-7B model loaded successfully!")
        return model, tokenizer

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        # Reset both globals so a later call retries from scratch instead of
        # caching a half-built (tokenizer-only) state.
        model = None
        tokenizer = None
        return None, None
 
 
 
 
 
82
 
83
def analyze_satisfaction(user_input: str) -> str:
    """Generate satisfaction analysis based on user input text.

    Args:
        user_input: Free-form description of the user's life situation.

    Returns:
        A markdown-formatted analysis string, or an error/warning message.
    """
    # Load model if not already loaded
    model, tokenizer = load_model()

    if model is None or tokenizer is None:
        return "❌ Error: Could not load the model. Please try again later."

    if not user_input.strip():
        return "⚠️ Please enter some text describing your life situation or what you'd like analyzed."

    try:
        # Tokenize input
        inputs = tokenizer(
            user_input,
            return_tensors="pt",
            truncation=True,
            max_length=Config.MAX_INPUT_LENGTH
        )

        # Move to GPU if available (model.device is the first device when
        # device_map="auto" shards the model)
        if torch.cuda.is_available():
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Number of prompt tokens, used below to strip the prompt reliably.
        prompt_len = inputs["input_ids"].shape[1]

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=Config.MAX_NEW_TOKENS,
                temperature=Config.TEMPERATURE,
                top_p=Config.TOP_P,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # BUG FIX: previously the prompt was removed with
        # full_response[len(user_input):], which slices by *character* count of
        # the raw input. The decoded prompt need not match the raw input
        # byte-for-byte (tokenizer normalization, truncation at
        # MAX_INPUT_LENGTH), so that could cut off or leak prompt text.
        # Slice by token position instead and decode only the new tokens.
        generated_text = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

        # Format the output
        formatted_output = "## 📊 Life Satisfaction Analysis\n\n"
        formatted_output += generated_text

        return formatted_output

    except Exception as e:
        return f"❌ Error during analysis: {str(e)}"
134
 
135
# Example prompts shown in the UI dropdown so users can try the tool quickly.
EXAMPLE_PROMPTS = [
    "I'm a 29-year-old professional feeling burned out at work. My health is okay but I rarely exercise. Financially stable but not saving much. Great relationship with my partner. What's my life satisfaction score?",
    "Rate my life satisfaction: Work is stressful (3/10), health is good (7/10), finances are tight (4/10), relationships are excellent (9/10). Give me a comprehensive analysis.",
    "Analyze my satisfaction: Career going well, making good money, but no time for friends or hobbies. Always tired and stressed. How can I improve?",
    "I'm happy with my job and relationships but struggling with debt and health issues. Need advice on balancing everything.",
    "Just graduated, starting my career, living paycheck to paycheck, single but happy, very healthy and active. Analyze my life satisfaction.",
]
143
 
144
# Gradio Interface
def create_interface():
    """Assemble the Gradio Blocks UI and return it (does not launch it)."""

    with gr.Blocks(title="Life Satisfaction Analysis", theme=gr.themes.Soft()) as ui:
        # Header / intro copy
        gr.Markdown(
            """
            # 🌟 AI Life Satisfaction Analyzer

            This AI-powered tool analyzes your life satisfaction based on your description of your current situation.
            Simply describe your life circumstances, challenges, and satisfaction levels across different areas.

            **The AI will analyze:**
            - Overall life satisfaction score
            - Balance across life domains (work, health, finances, relationships)
            - Personalized recommendations for improvement
            - Action plans and strategies
            """
        )

        with gr.Row():
            with gr.Column():
                # Free-form input describing the user's situation.
                situation_box = gr.Textbox(
                    label="📝 Describe Your Current Life Situation",
                    placeholder="Tell me about your work, health, finances, relationships, and any other aspects of your life you'd like analyzed. You can include satisfaction ratings (1-10) or just describe how you feel about each area.",
                    lines=8,
                    max_lines=15,
                )

                with gr.Row():
                    run_btn = gr.Button("🔍 Analyze My Life Satisfaction", variant="primary", scale=2)
                    reset_btn = gr.Button("🗑️ Clear", scale=1)

                # Canned example prompts the user can pick from.
                gr.Markdown("### 💡 Example Inputs")
                examples = gr.Dropdown(
                    choices=EXAMPLE_PROMPTS,
                    label="Select an example to try:",
                    interactive=True,
                )

        with gr.Row():
            with gr.Column():
                # Analysis output, rendered as markdown.
                result_md = gr.Markdown(label="Analysis Results")

        # Event wiring: analyze, clear both fields, copy example into input.
        run_btn.click(fn=analyze_satisfaction, inputs=situation_box, outputs=result_md)
        reset_btn.click(fn=lambda: ("", ""), inputs=[], outputs=[situation_box, result_md])
        examples.change(fn=lambda x: x, inputs=examples, outputs=situation_box)

        # Collapsible usage tips.
        with gr.Accordion("📖 Tips for Best Results", open=False):
            gr.Markdown(
                """
                **How to get the most accurate analysis:**

                1. **Be specific** about your situation in each life area
                2. **Include ratings** (1-10) if you want quantified analysis
                3. **Mention your age** and life stage for context
                4. **Describe challenges** you're facing
                5. **Share your goals** or what you'd like to improve

                **Example format:**
                - Work: [Your situation and satisfaction level]
                - Health: [Physical and mental wellness status]
                - Finances: [Financial situation and concerns]
                - Relationships: [Social and romantic relationships]
                - Personal: [Hobbies, growth, fulfillment]
                """
            )

        # Footer / disclaimer
        gr.Markdown(
            """
            ---
            💡 **Disclaimer:** This AI tool provides general insights based on the information you provide.
            For professional advice, please consult qualified experts in relevant fields.

            🔒 **Privacy:** Your input is processed in real-time and not stored.
            """
        )

    return ui
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
# Launch the app
if __name__ == "__main__":
    print("🚀 Starting Life Satisfaction Analysis Tool...")

    # Warm the model cache up front so the first request isn't slow.
    load_model()

    # Build the UI and start serving.
    demo = create_interface()
    demo.launch()