eesfeg committed on
Commit
9cb84f0
·
1 Parent(s): 252f73e
Files changed (1) hide show
  1. app.py +187 -36
app.py CHANGED
@@ -1,47 +1,198 @@
1
- import gradio as gr
 
 
 
 
 
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
 
5
- MODEL_ID = "abdelac/Mistral_Test" # your model repo
 
 
 
 
 
 
 
 
6
 
7
- print("πŸ” Loading tokenizer...")
8
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
9
 
10
- print("πŸ” Loading model...")
11
- model = AutoModelForCausalLM.from_pretrained(
12
- MODEL_ID,
13
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
14
- device_map="auto"
15
- )
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- model.eval()
18
- print("βœ… Model loaded")
19
 
20
- # ---------- Inference function ----------
21
- def generate(prompt):
22
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- with torch.no_grad():
25
- output_ids = model.generate(
26
- **inputs,
27
- max_new_tokens=128,
28
- temperature=0.7,
29
- do_sample=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- return tokenizer.decode(output_ids[0], skip_special_tokens=True)
33
-
34
- # ---------- Gradio UI ----------
35
- demo = gr.Interface(
36
- fn=generate,
37
- inputs=gr.Textbox(lines=4, label="Prompt"),
38
- outputs=gr.Textbox(lines=8, label="Output"),
39
- title="Mistral Test – Space Inference"
40
- )
41
-
42
- demo.launch(
43
- server_name="0.0.0.0",
44
- server_port=7860,
45
- ssr_mode=False,
46
- )
 
 
 
 
 
 
 
 
 
 
 
 
47
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import sys
5
+ import asyncio
6
+ import warnings
7
+ import signal
8
  import torch
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ import gradio as gr
11
 
12
+ # =================== ASYNCIO FIX ===================
13
+ # Fix for the asyncio cleanup error
14
+ if sys.version_info >= (3, 8) and sys.platform.startswith('linux'):
15
+ # This prevents the error on Linux with Python 3.8+
16
+ try:
17
+ import uvloop
18
+ uvloop.install()
19
+ except ImportError:
20
+ pass
21
 
22
+ # Suppress warnings
23
+ warnings.filterwarnings("ignore")
24
+ os.environ["PYTHONWARNINGS"] = "ignore"
25
 
26
+ # =================== MODEL LOADING ===================
27
+ @gr.cache_resource
28
+ def load_model():
29
+ """Load the TinyLlama model"""
30
+ print("πŸš€ Loading TinyLlama model...")
31
+ MODEL_ID = "abdelac/tinyllama"
32
+
33
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
34
+ model = AutoModelForCausalLM.from_pretrained(
35
+ MODEL_ID,
36
+ torch_dtype=torch.float32,
37
+ device_map="auto",
38
+ low_cpu_mem_usage=True
39
+ )
40
+
41
+ print("βœ… Model loaded successfully!")
42
+ return tokenizer, model
43
 
44
+ # Load model once
45
+ tokenizer, model = load_model()
46
 
47
+ # =================== GENERATION FUNCTION ===================
48
+ def generate_text(prompt, max_tokens=150, temperature=0.7):
49
+ """Generate text based on prompt"""
50
+ try:
51
+ # Tokenize
52
+ inputs = tokenizer(prompt, return_tensors="pt")
53
+
54
+ # Generate
55
+ with torch.no_grad():
56
+ outputs = model.generate(
57
+ **inputs,
58
+ max_new_tokens=max_tokens,
59
+ temperature=temperature,
60
+ do_sample=True,
61
+ pad_token_id=tokenizer.eos_token_id
62
+ )
63
+
64
+ # Decode
65
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
+ return result
67
+
68
+ except Exception as e:
69
+ return f"❌ Error: {str(e)}"
70
 
71
+ # =================== GRADIO INTERFACE ===================
72
+ def create_interface():
73
+ """Create the Gradio interface"""
74
+ with gr.Blocks(
75
+ title="πŸ¦™ TinyLlama Demo",
76
+ theme=gr.themes.Soft(),
77
+ css=".gradio-container {max-width: 800px !important}"
78
+ ) as demo:
79
+
80
+ gr.Markdown("""
81
+ # πŸ¦™ TinyLlama Text Generator
82
+
83
+ Generate text using the TinyLlama model (1.1B parameters)
84
+
85
+ **Model**: [abdelac/tinyllama](https://huggingface.co/abdelac/tinyllama)
86
+ """)
87
+
88
+ with gr.Row():
89
+ with gr.Column(scale=2):
90
+ prompt = gr.Textbox(
91
+ label="πŸ“ Input Prompt",
92
+ placeholder="Type your text here...",
93
+ lines=5,
94
+ value="Once upon a time in a magical forest,"
95
+ )
96
+
97
+ with gr.Row():
98
+ max_tokens = gr.Slider(
99
+ 50, 500, value=150,
100
+ label="πŸ“ Max Tokens",
101
+ info="Maximum length of generated text"
102
+ )
103
+ temperature = gr.Slider(
104
+ 0.1, 2.0, value=0.7,
105
+ label="🌑️ Temperature",
106
+ info="Higher = more creative, Lower = more focused"
107
+ )
108
+
109
+ with gr.Row():
110
+ generate_btn = gr.Button(
111
+ "✨ Generate",
112
+ variant="primary",
113
+ size="lg"
114
+ )
115
+ clear_btn = gr.Button(
116
+ "πŸ—‘οΈ Clear",
117
+ variant="secondary"
118
+ )
119
+
120
+ with gr.Column(scale=3):
121
+ output = gr.Textbox(
122
+ label="πŸ“„ Generated Text",
123
+ lines=12,
124
+ interactive=False
125
+ )
126
+
127
+ # Examples
128
+ gr.Examples(
129
+ examples=[
130
+ ["Write a short story about a robot learning to paint"],
131
+ ["Explain quantum computing in simple terms"],
132
+ ["Python function to calculate fibonacci sequence:"],
133
+ ["The benefits of renewable energy include"],
134
+ ["Write a poem about artificial intelligence"]
135
+ ],
136
+ inputs=prompt,
137
+ label="πŸ’‘ Try these examples"
138
  )
139
+
140
+ # Functions
141
+ generate_btn.click(
142
+ fn=generate_text,
143
+ inputs=[prompt, max_tokens, temperature],
144
+ outputs=output,
145
+ api_name="generate"
146
+ )
147
+
148
+ clear_btn.click(
149
+ fn=lambda: ("", ""),
150
+ inputs=[],
151
+ outputs=[prompt, output]
152
+ )
153
+
154
+ # Status
155
+ gr.Markdown("---")
156
+ gr.Markdown("""
157
+ <div style='text-align: center; color: #666; font-size: 0.9em;'>
158
+ βœ… Model loaded successfully | πŸš€ Ready to generate text
159
+ </div>
160
+ """)
161
+
162
+ return demo
163
 
164
+ # =================== MAIN ENTRY POINT ===================
165
+ def main():
166
+ """Main function with proper cleanup"""
167
+ demo = create_interface()
168
+
169
+ # Clean launch configuration
170
+ try:
171
+ demo.launch(
172
+ server_name="0.0.0.0",
173
+ server_port=7860,
174
+ share=False,
175
+ quiet=True, # Reduce console output
176
+ debug=False, # Disable debug mode
177
+ show_error=True, # Show errors in UI
178
+ favicon_path=None,
179
+ ssl_verify=True,
180
+ max_file_size="2MB",
181
+ allowed_paths=["./"],
182
+ blocked_paths=[],
183
+ show_api=True
184
+ )
185
+ except KeyboardInterrupt:
186
+ print("\nπŸ‘‹ Shutting down gracefully...")
187
+ sys.exit(0)
188
+ except Exception as e:
189
+ print(f"❌ Error: {e}")
190
+ sys.exit(1)
191
 
192
+ if __name__ == "__main__":
193
+ # Set up signal handlers for clean shutdown
194
+ signal.signal(signal.SIGINT, lambda s, f: sys.exit(0))
195
+ signal.signal(signal.SIGTERM, lambda s, f: sys.exit(0))
196
+
197
+ # Run the app
198
+ main()