eesfeg committed on
Commit
2febca8
·
1 Parent(s): 75e6b29
Files changed (1) hide show
  1. app.py +116 -139
app.py CHANGED
@@ -2,197 +2,174 @@
2
 
3
  import os
4
  import sys
5
- import asyncio
6
  import warnings
7
- import signal
8
  import torch
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
  import gradio as gr
11
 
12
# =================== EVENT LOOP AND WARNING SETUP ===================
# Added by the original author as a workaround for an asyncio cleanup error:
# on Linux with Python 3.8+, uvloop is installed as the event loop when the
# package is available; otherwise the default loop is kept.
if sys.platform.startswith('linux'):
    if sys.version_info >= (3, 8):
        try:
            import uvloop
            uvloop.install()
        except ImportError:
            # uvloop is optional; nothing to do when it is missing.
            pass

# Silence warnings in this process and, through the environment variable,
# in any Python child process it spawns.
warnings.simplefilter("ignore")
os.environ.update(PYTHONWARNINGS="ignore")
 
 
 
 
 
 
 
 
25
 
26
  # =================== MODEL LOADING ===================
# Holds the (tokenizer, model) pair once it has been loaded, so repeated calls
# to load_model() do not read the weights again.
_LOADED_MODEL = None


def load_model():
    """Load the Mistral_Test tokenizer and model, reusing them after the first call.

    NOTE: the previous ``@gr.cache_resource`` decorator is not a documented
    Gradio API (the name matches Streamlit's ``st.cache_resource``), so the
    caching is done explicitly with a module-level variable instead.

    Returns:
        tuple: ``(tokenizer, model)`` as returned by
        ``AutoTokenizer.from_pretrained`` and
        ``AutoModelForCausalLM.from_pretrained``.
    """
    global _LOADED_MODEL
    if _LOADED_MODEL is not None:
        return _LOADED_MODEL

    print("🚀 Loading Mistral_Test model...")
    model_id = "abdelac/Mistral_Test"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        offload_folder="offload",
    )

    print("✅ Model loaded successfully!")
    _LOADED_MODEL = (tokenizer, model)
    return _LOADED_MODEL


# Load model once
tokenizer, model = load_model()
47
-
48
- # =================== GENERATION FUNCTION ===================
49
def generate_text(prompt, max_tokens=150, temperature=0.7):
    """Generate a sampled continuation of ``prompt`` with the module-level model.

    Args:
        prompt: Text to tokenize and feed to the model.
        max_tokens: Upper bound on newly generated tokens. It is coerced to a
            positive integer because this function is also exposed as an API
            endpoint, where callers may send floats or non-positive values.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        str: The decoded first output sequence with special tokens removed, or
        a string starting with "❌ Error:" if anything raised.
    """
    try:
        # Validate inside the try block so a bad value is reported through the
        # same error string as every other failure.
        new_token_budget = max(1, int(max_tokens))

        inputs = tokenizer(prompt, return_tensors="pt")

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=new_token_budget,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    except Exception as e:
        # UI boundary: show the failure in the output box instead of crashing.
        return f"❌ Error: {str(e)}"
71
 
72
- # =================== GRADIO INTERFACE ===================
def create_interface():
    """Build the Gradio Blocks UI for the text generator.

    The layout has a prompt box, max-token and temperature sliders, Generate
    and Clear buttons, an output box, and a list of example prompts. The header
    now names the Mistral_Test model that this app loads; it previously
    advertised TinyLlama and linked to a different repository.

    Returns:
        The ``gr.Blocks`` app, ready for ``launch()``.
    """
    with gr.Blocks(
        title="🦙 Mistral_Test Demo",
        theme=gr.themes.Soft(),
        css=".gradio-container {max-width: 800px !important}"
    ) as demo:

        gr.Markdown("""
        # 🦙 Mistral_Test Text Generator

        Generate text using the Mistral_Test model

        **Model**: [abdelac/Mistral_Test](https://huggingface.co/abdelac/Mistral_Test)
        """)

        with gr.Row():
            # Left column: prompt and generation controls.
            with gr.Column(scale=2):
                prompt = gr.Textbox(
                    label="📝 Input Prompt",
                    placeholder="Type your text here...",
                    lines=5,
                    value="Once upon a time in a magical forest,"
                )

                with gr.Row():
                    max_tokens = gr.Slider(
                        50, 500, value=150,
                        label="📏 Max Tokens",
                        info="Maximum length of generated text"
                    )
                    temperature = gr.Slider(
                        0.1, 2.0, value=0.7,
                        label="🌡️ Temperature",
                        info="Higher = more creative, Lower = more focused"
                    )

                with gr.Row():
                    generate_btn = gr.Button(
                        "✨ Generate",
                        variant="primary",
                        size="lg"
                    )
                    clear_btn = gr.Button(
                        "🗑️ Clear",
                        variant="secondary"
                    )

            # Right column: read-only generated text.
            with gr.Column(scale=3):
                output = gr.Textbox(
                    label="📄 Generated Text",
                    lines=12,
                    interactive=False
                )

        # Clicking an example fills the prompt box.
        gr.Examples(
            examples=[
                ["Write a short story about a robot learning to paint"],
                ["Explain quantum computing in simple terms"],
                ["Python function to calculate fibonacci sequence:"],
                ["The benefits of renewable energy include"],
                ["Write a poem about artificial intelligence"]
            ],
            inputs=prompt,
            label="💡 Try these examples"
        )

        # Generate: run the model and show the result; also exposed as an
        # API endpoint named "generate".
        generate_btn.click(
            fn=generate_text,
            inputs=[prompt, max_tokens, temperature],
            outputs=output,
            api_name="generate"
        )

        # Clear: reset both the prompt and the output to empty strings.
        clear_btn.click(
            fn=lambda: ("", ""),
            inputs=[],
            outputs=[prompt, output]
        )

        # Static footer.
        gr.Markdown("---")
        gr.Markdown("""
        <div style='text-align: center; color: #666; font-size: 0.9em;'>
        Model loaded successfully | 🚀 Ready to generate text
        </div>
        """)

    return demo
164
 
165
- # =================== MAIN ENTRY POINT ===================
166
def main():
    """Build the interface and serve it, exiting with a status code on failure.

    Exits with status 0 on Ctrl+C and with status 1 on any other exception
    raised by ``launch``.
    """
    app = create_interface()

    # NOTE(review): allowed_paths=["./"] appears to let Gradio serve files
    # from the working directory - confirm this is intended for this app.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "quiet": True,        # Reduce console output
        "debug": False,       # Disable debug mode
        "show_error": True,   # Show errors in UI
        "favicon_path": None,
        "ssl_verify": True,
        "max_file_size": "2MB",
        "allowed_paths": ["./"],
        "blocked_paths": [],
    }

    try:
        app.launch(**launch_options)
    except KeyboardInterrupt:
        print("\n👋 Shutting down gracefully...")
        sys.exit(0)
    except Exception as e:
        print(f"❌ Error: {e}")
        sys.exit(1)
191
-
if __name__ == "__main__":
    # Exit with status 0 on Ctrl+C (SIGINT) or a termination request (SIGTERM).
    for shutdown_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(shutdown_signal, lambda s, f: sys.exit(0))

    # Build the UI and start serving.
    main()
 
 
 
 
 
 
 
 
 
2
 
3
  import os
4
  import sys
 
5
  import warnings
 
6
  import torch
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
8
  import gradio as gr
9
 
10
# =================== CONFIGURATION ===================
MODEL_ID = "abdelac/Mistral_Test"
USE_QUANTIZATION = True  # MUST be True for 16GB RAM

# =================== QUANTIZATION SETUP ===================
# bnb_config is a BitsAndBytesConfig when quantization is enabled, else None.
bnb_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,                       # Critical for memory
        bnb_4bit_quant_type="nf4",               # 4-bit quantization
        bnb_4bit_compute_dtype=torch.float16,    # Compute in float16
        bnb_4bit_use_double_quant=True,          # Extra memory savings
        llm_int8_enable_fp32_cpu_offload=True,   # Offload to CPU if needed
    )
    if USE_QUANTIZATION
    else None
)
25
 
26
  # =================== MODEL LOADING ===================
# Holds the (tokenizer, model) pair once it has been loaded, so every
# generate_text() call after the first reuses the same objects.
_LOADED_MODEL = None


def load_model():
    """Load the tokenizer and model for ``MODEL_ID``, reusing them after the first call.

    NOTE: the previous ``@gr.cache_resource`` decorator is not a documented
    Gradio API (the name matches Streamlit's ``st.cache_resource``), so the
    caching is done explicitly with a module-level variable instead.

    With ``USE_QUANTIZATION`` set, the model is loaded with ``bnb_config`` and
    ``device_map="auto"``; otherwise it is placed on the CPU.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    global _LOADED_MODEL
    if _LOADED_MODEL is not None:
        return _LOADED_MODEL

    print(f"🚀 Loading {MODEL_ID}...")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # Configure model loading based on quantization
    load_kwargs = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "low_cpu_mem_usage": True,
    }

    if USE_QUANTIZATION:
        load_kwargs["quantization_config"] = bnb_config
        print("✅ Using 4-bit quantization (~4GB RAM)")
    else:
        load_kwargs["device_map"] = "cpu"
        print("⚠️ Using CPU only (slow but safe)")

    # Load model
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)

    # Fall back to the EOS token when the tokenizer defines no padding token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("✅ Model loaded successfully!")
    _LOADED_MODEL = (tokenizer, model)
    return _LOADED_MODEL
61
 
62
+ # =================== MEMORY-EFFICIENT GENERATION ===================
63
def generate_text(prompt, max_tokens=100, temperature=0.7):
    """Generate a sampled continuation of ``prompt`` under memory constraints.

    Args:
        prompt: Text to continue; it is truncated to 512 tokens.
        max_tokens: Requested number of new tokens. It is coerced to an int
            and clamped to the range 1-150, so float or non-positive values
            cannot reach ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        str: The decoded first output sequence with special tokens removed, or
        a string starting with "❌" describing the failure.
    """
    try:
        tokenizer, model = load_model()

        # Validate inside the try block so a bad value is reported through the
        # same error string as every other failure.
        new_token_budget = max(1, min(int(max_tokens), 150))  # Cap at 150

        # Tokenize with truncation and move the tensors to the model's device.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(model.device)

        # Generate with conservative settings. early_stopping is not passed:
        # it only applies to beam search, and this call samples.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=new_token_budget,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,  # Prevent repetition
                no_repeat_ngram_size=2,
            )

        # Decode
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    except torch.cuda.OutOfMemoryError:
        return "❌ Out of memory! Try reducing max tokens or using CPU mode."
    except Exception as e:
        # UI boundary: show the failure in the output box instead of crashing.
        return f"❌ Error: {str(e)}"
97
 
98
+ # =================== SIMPLIFIED INTERFACE ===================
def create_interface():
    """Assemble the Gradio Blocks UI for the demo.

    The page shows a header with the model id and loading mode, a prompt box,
    max-token and temperature sliders, a Generate button wired to
    ``generate_text``, an output box, and a list of memory tips.

    Returns:
        The ``gr.Blocks`` app, ready for ``launch()``.
    """
    mode_label = '4-bit Quantized' if USE_QUANTIZATION else 'CPU'

    with gr.Blocks(title="🦅 Mistral Test Demo", theme=gr.themes.Soft()) as app:

        header_markdown = f"""
        # 🦅 Mistral Test Demo

        **Model:** [{MODEL_ID}](https://huggingface.co/{MODEL_ID})
        **Mode:** {mode_label}

        ⚠️ **Note:** Mistral 7B requires quantization to run in 16GB RAM
        """
        gr.Markdown(header_markdown)

        with gr.Row():
            prompt_box = gr.Textbox(
                value="What is artificial intelligence?",
                label="Prompt",
                placeholder="Enter your text...",
                lines=3,
            )

        with gr.Row():
            # Reduced max for memory
            token_slider = gr.Slider(
                minimum=30,
                maximum=150,
                value=80,
                label="Max Tokens",
                info="Higher values use more memory",
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                label="Temperature",
            )

        run_button = gr.Button(value="Generate", variant="primary", size="lg")

        result_box = gr.Textbox(
            label="Generated Text",
            lines=8,
            show_copy_button=True,
        )

        # Static usage tips shown below the output.
        gr.Markdown("""
        ### 💡 Memory Optimization Tips:
        1. **Max Tokens 100** for best results
        2. **Temperature ~0.7** for balanced output
        3. If OOM occurs, refresh the page
        4. Close other tabs/applications
        """)

        # Wire the button: the three inputs go to generate_text, and its
        # return value is written to the output box.
        run_button.click(
            fn=generate_text,
            inputs=[prompt_box, token_slider, temperature_slider],
            outputs=result_box,
        )

    return app
159
 
160
+ # =================== MAIN ===================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
if __name__ == "__main__":
    # Silence Python warnings and set TF_CPP_MIN_LOG_LEVEL to "3".
    warnings.simplefilter("ignore")
    os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")

    # Build the UI, then serve it on all interfaces at port 7860.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "quiet": True,
        "debug": False,
        "show_error": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)