Scaryscar committed on
Commit
ccdf56b
·
verified ·
1 Parent(s): 2a87fff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -116
app.py CHANGED
@@ -3,56 +3,44 @@ import torch
3
  import time
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
- import pandas as pd
7
- import plotly.express as px
8
- from transformers import pipeline, AutoTokenizer
9
  from io import BytesIO
10
  import base64
11
- import warnings
12
- warnings.filterwarnings("ignore")
13
 
14
- # ===== CORE SYSTEM =====
15
- class AISystem:
16
  def __init__(self):
17
  self.device = 0 if torch.cuda.is_available() else -1
18
  self.dtype = torch.float16 if self.device == 0 else torch.float32
19
  self.model = None
20
- self.tokenizer = None
21
- self.load_models()
22
 
23
- def load_models(self):
24
- """Smart model loading with multiple fallbacks"""
25
- models = [
26
- ("mistralai/Mistral-7B-v0.1", {}), # Open-access
27
- ("google/gemma-2b-it", {"low_cpu_mem_usage": True}) # Gated
28
- ]
29
-
30
- for model_name, kwargs in models:
31
- try:
32
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
33
- self.model = pipeline(
34
- "text-generation",
35
- model=model_name,
36
- tokenizer=self.tokenizer,
37
- device=self.device,
38
- torch_dtype=self.dtype,
39
- **kwargs
40
- )
41
- # Verify model works
42
- test_output = self.generate("Test", simple=True)
43
- if test_output and len(test_output.split()) > 3:
44
- print(f"✅ Loaded {model_name}")
45
- return
46
- except Exception as e:
47
- print(f"⚠️ Failed {model_name}: {str(e)}")
48
-
49
- raise RuntimeError("All models failed to load")
50
 
51
  def generate(self, prompt, simple=False):
52
- """Guaranteed generation with error handling"""
 
 
 
53
  try:
54
  full_prompt = prompt if simple else f"""
55
- Provide a detailed, step-by-step answer. Include graphs if requested.
 
56
 
57
  Question: {prompt}
58
 
@@ -61,145 +49,138 @@ class AISystem:
61
 
62
  output = self.model(
63
  full_prompt,
64
- max_new_tokens=250,
65
  temperature=0.7,
66
  do_sample=True,
67
- pad_token_id=self.tokenizer.eos_token_id
68
  )[0]['generated_text']
69
 
70
- return output.split("Answer:")[-1].strip()
71
  except Exception:
72
- return "I couldn't generate a response. Please try again."
73
 
74
- def create_graph(self, data_type):
75
- """Generate different graph types"""
76
  try:
77
- x = np.linspace(0, 10, 100)
78
- if data_type == "linear":
 
 
79
  y = x
80
- plt.plot(x, y)
81
- plt.title("Linear Relationship")
82
- elif data_type == "quadratic":
83
  y = x**2
84
- plt.plot(x, y)
85
- plt.title("Quadratic Relationship")
86
- elif data_type == "random":
87
- y = np.random.rand(100)
88
- plt.scatter(x, y)
89
- plt.title("Random Data")
90
 
91
  plt.xlabel("X-axis")
92
  plt.ylabel("Y-axis")
 
 
93
  buf = BytesIO()
94
- plt.savefig(buf, format='png')
95
  plt.close()
96
- return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"
97
  except Exception:
98
  return None
99
 
100
- # Initialize system
101
- try:
102
- ai_system = AISystem()
103
- except Exception as e:
104
- print(f"🔴 System initialization failed: {str(e)}")
105
- ai_system = None
 
 
 
 
106
 
107
  # ===== GRADIO INTERFACE =====
108
- def process_query(prompt):
109
  start_time = time.time()
110
 
111
  if not prompt.strip():
112
- return "Please enter a valid question", None
113
-
114
- if ai_system is None:
115
- return "System initialization failed - please check logs", None
116
-
117
- # Check for graph requests
118
- graph_type = None
119
- graph_keywords = {
120
- "linear graph": "linear",
121
- "quadratic graph": "quadratic",
122
- "random data": "random",
123
- "plot": "linear",
124
- "chart": "linear"
125
- }
126
-
127
- for keyword, g_type in graph_keywords.items():
128
- if keyword in prompt.lower():
129
- graph_type = g_type
130
- break
131
 
132
  # Generate response
133
- response = ai_system.generate(prompt)
134
 
135
- # Create graph if requested
136
- graph = None
137
- if graph_type:
138
- graph = ai_system.create_graph(graph_type)
 
 
 
 
139
 
140
  # Format output
141
  gen_time = time.time() - start_time
142
- formatted_response = f"""📊 Step-by-Step Answer:
143
-
144
- {response}
145
 
146
  ⏱️ Generated in {gen_time:.2f} seconds"""
147
 
148
- return formatted_response, graph
149
 
150
- with gr.Blocks(theme=gr.themes.Soft(), title="🧠 AI Expert Assistant") as demo:
151
- gr.Markdown("""<h1><center>Intelligent Answer Engine</center></h1>""")
152
 
153
  with gr.Row():
154
- query = gr.Textbox(
155
  label="Your Question",
156
  placeholder="Ask anything... (e.g. 'Explain photosynthesis and show a linear graph')",
157
  lines=3
158
  )
159
 
160
  with gr.Row():
161
- submit_btn = gr.Button("Generate Answer", variant="primary")
162
 
163
  with gr.Row():
164
  answer = gr.Textbox(
165
  label="Detailed Explanation",
166
- lines=8,
167
  interactive=False
168
  )
169
 
170
  with gr.Row():
171
- graph_output = gr.Image(
172
- label="Generated Graph",
173
  visible=False
174
  )
175
 
176
- # Example queries
177
  gr.Examples(
178
  examples=[
179
- "Explain quantum computing and show a linear graph",
180
- "Describe the water cycle with a quadratic graph",
181
- "How does machine learning work? Show random data"
182
  ],
183
- inputs=query
184
  )
185
 
186
- def update_ui(response, graph):
187
- if graph:
188
- return response, gr.update(visible=True, value=graph)
189
- return response, gr.update(visible=False)
190
 
191
  submit_btn.click(
192
- fn=process_query,
193
- inputs=query,
194
- outputs=[answer, graph_output]
195
  ).then(
196
- fn=update_ui,
197
- inputs=[answer, graph_output],
198
- outputs=[answer, graph_output]
199
  )
200
 
201
  if __name__ == "__main__":
202
  demo.launch(
203
  server_name="0.0.0.0",
204
- server_port=7860
205
- )
 
 
 
3
  import time
4
  import matplotlib.pyplot as plt
5
  import numpy as np
 
 
 
6
  from io import BytesIO
7
  import base64
8
+ from transformers import pipeline
 
9
 
10
# ===== FAILSAFE SYSTEM =====
class RobustAISystem:
    """Text-generation service that degrades gracefully.

    Loads a Hugging Face text-generation pipeline once at construction time.
    If loading (or a smoke test) fails, ``self.model`` stays ``None`` and
    ``generate`` returns a friendly placeholder instead of raising, so the
    surrounding Gradio app keeps working.
    """

    def __init__(self):
        # GPU (device 0) with fp16 when CUDA is available, otherwise CPU fp32.
        self.device = 0 if torch.cuda.is_available() else -1
        self.dtype = torch.float16 if self.device == 0 else torch.float32
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the text-generation pipeline; leave ``self.model`` as None on failure.

        Never raises: every failure is logged and swallowed so the app can
        still start and report "initializing" to users.
        """
        try:
            self.model = pipeline(
                "text-generation",
                # NOTE(review): this repo is gated on the Hub — it only "always
                # works" when valid credentials are configured; confirm in deploy env.
                model="mistralai/Mistral-7B-v0.1",
                device=self.device,
                torch_dtype=self.dtype
            )
            # Smoke-test the pipeline so a broken load is detected here rather
            # than on the first user request.
            test = self.generate("Test", simple=True)
            if not test.strip():
                raise RuntimeError("Blank response")
        except Exception as e:
            print(f"Model load failed: {str(e)}")
            self.model = None

    def generate(self, prompt, simple=False):
        """Guaranteed to return a response.

        Args:
            prompt: The user question.
            simple: When True, send ``prompt`` verbatim (used by the load-time
                smoke test); otherwise wrap it in an instruction template.

        Returns:
            The model's answer text, or a fallback message — never raises.
        """
        if not self.model:
            return "System is initializing... Please wait"

        try:
            full_prompt = prompt if simple else f"""
Provide a detailed, step-by-step answer. If the question involves data or relationships,
describe what kind of graph would best represent it.

Question: {prompt}

Answer:"""

            output = self.model(
                full_prompt,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                # Silence the "pad_token_id not set" warning for open-ended generation.
                pad_token_id=self.model.tokenizer.eos_token_id
            )[0]['generated_text']

            # The pipeline echoes the prompt; keep only the text after "Answer:".
            return output.split("Answer:")[-1].strip() or "I couldn't generate a response. Please try again."
        except Exception:
            return "Error generating response. Please rephrase your question."

    def create_graph(self, graph_type):
        """Render a small demo graph and return it base64-encoded.

        Args:
            graph_type: "linear", "quadratic", or anything else (falls back
                to a sine curve).

        Returns:
            Base64-encoded PNG bytes as a str, or None if rendering failed.
        """
        try:
            plt.figure(figsize=(8, 4))
            x = np.linspace(0, 10, 50)

            if graph_type == "linear":
                y = x
                plt.plot(x, y, 'b-')
                plt.title("Linear Relationship (y = x)")
            elif graph_type == "quadratic":
                y = x**2
                plt.plot(x, y, 'r-')
                plt.title("Quadratic Relationship (y = x²)")
            else:  # Default case
                y = np.sin(x)
                plt.plot(x, y, 'g-')
                plt.title("Periodic Relationship (y = sin(x))")

            plt.xlabel("X-axis")
            plt.ylabel("Y-axis")
            plt.grid(True)

            buf = BytesIO()
            plt.savefig(buf, format='png', dpi=100)
            return base64.b64encode(buf.getvalue()).decode('utf-8')
        except Exception:
            return None
        finally:
            # Always release the figure — the original only closed it on the
            # happy path, leaking figures when plotting/saving raised.
            plt.close()
91
 
92
# Bring the AI system up, retrying a few times before giving up so a
# transient download/initialization hiccup doesn't kill the app.
ai_system = None
attempt = 0
while attempt < 3:
    attempt += 1
    try:
        candidate = RobustAISystem()
    except Exception as exc:
        print(f"Initialization attempt failed: {str(exc)}")
        time.sleep(2)  # brief back-off before the next try
        continue
    ai_system = candidate
    if ai_system.model:
        break
102
 
103
# ===== GRADIO INTERFACE =====
def process_request(prompt):
    """Answer a user question and optionally attach a generated graph.

    Args:
        prompt: Raw text from the question box.

    Returns:
        A ``(formatted_answer, graph_data_uri_or_None)`` tuple for the
        Gradio outputs.
    """
    start_time = time.time()

    # Guard: nothing to do for a blank prompt.
    if not prompt.strip():
        return "Please enter a question", None

    # Produce the textual answer (placeholder while the model is warming up).
    if ai_system:
        response = ai_system.generate(prompt)
    else:
        response = "System starting up... Try again in 30 seconds"

    # Attach a graph when the prompt hints that a visualization is wanted.
    graph_img = None
    lowered = prompt.lower()
    graph_triggers = ["graph", "plot", "chart", "visualize", "diagram"]
    wants_graph = any(keyword in lowered for keyword in graph_triggers)
    if wants_graph:
        graph_type = "quadratic" if "quadratic" in lowered else "linear"
        graph_b64 = ai_system.create_graph(graph_type) if ai_system else None
        if graph_b64:
            # Inline data URI so gr.Image can display it directly.
            graph_img = f"data:image/png;base64,{graph_b64}"

    # Append timing info to the answer.
    gen_time = time.time() - start_time
    formatted_response = f"""{response}

⏱️ Generated in {gen_time:.2f} seconds"""

    return formatted_response, graph_img
129
 
130
# Build the Gradio UI. `demo` is referenced by the __main__ launch guard below.
with gr.Blocks(theme=gr.themes.Default(), title="🔍 AI Assistant") as demo:
    gr.Markdown("""<h1><center>Intelligent Q&A with Visualizations</center></h1>""")

    # --- Input row ---
    with gr.Row():
        question = gr.Textbox(
            label="Your Question",
            placeholder="Ask anything... (e.g. 'Explain photosynthesis and show a linear graph')",
            lines=3
        )

    # --- Action row ---
    with gr.Row():
        submit_btn = gr.Button("Get Answer", variant="primary")

    # --- Output rows: text answer plus an initially-hidden graph image ---
    with gr.Row():
        answer = gr.Textbox(
            label="Detailed Explanation",
            lines=10,
            interactive=False
        )

    with gr.Row():
        graph = gr.Image(
            label="Relevant Graph",
            visible=False
        )

    # Pre-tested examples
    gr.Examples(
        examples=[
            "Explain the relationship between force and acceleration with a graph",
            "Show a quadratic graph and explain its applications",
            "Describe population growth with a visual diagram"
        ],
        inputs=question
    )

    def _reveal_graph(response, img):
        """Show the image component only when a graph was actually produced."""
        has_graph = img is not None
        return response, gr.update(visible=has_graph, value=img)

    # Chain: compute the answer/graph, then toggle graph visibility.
    submit_btn.click(
        fn=process_request,
        inputs=question,
        outputs=[answer, graph]
    ).then(
        fn=_reveal_graph,
        inputs=[answer, graph],
        outputs=[answer, graph]
    )
179
 
180
if __name__ == "__main__":
    # Bind on all interfaces (container-friendly); surface tracebacks in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)