ZENLLC commited on
Commit
e1ae3d1
Β·
verified Β·
1 Parent(s): fd2b1fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -37
app.py CHANGED
@@ -7,19 +7,10 @@ from transformers import pipeline
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
  # ----------------------------
10
- # Load text model (fallbacks)
11
  # ----------------------------
12
- if device == "cuda":
13
- text_model_name = "HuggingFaceH4/zephyr-7b-beta"
14
- else:
15
- text_model_name = "google/flan-t5-base" # CPU-friendly
16
-
17
- chat_model = pipeline(
18
- "text-generation",
19
- model=text_model_name,
20
- device=0 if device=="cuda" else -1,
21
- return_full_text=False
22
- )
23
 
24
  # ----------------------------
25
  # Try to load Stable Diffusion (only if GPU)
@@ -29,24 +20,21 @@ if device == "cuda":
29
  try:
30
  from diffusers import StableDiffusionPipeline
31
  sd_model = StableDiffusionPipeline.from_pretrained(
32
- "stabilityai/stable-diffusion-2-1",
33
- torch_dtype=torch.float16
34
  ).to(device)
35
  except Exception as e:
36
  print("⚠️ Could not load Stable Diffusion:", e)
37
  sd_model = None
38
 
39
-
 
 
40
  SYSTEM_PROMPT = """You are ZEN Research Assistant.
41
  You can respond in ONE of these forms:
42
  - Image β†’ {"type":"image","prompt":"<prompt>"}
43
- - Chart β†’ {"type":"chart","title":"<chart title>","data":[{"x":[...], "y":[...], "label":"<series>"}]}
44
  - Simulation β†’ {"type":"simulation","topic":"<title>","steps":["...", "..."]}
45
  - Text β†’ plain conversation.
46
-
47
- Rules:
48
- - Use JSON ONLY for image, chart, or simulation.
49
- - Simulation = imaginative thought experiment, 3–6 steps.
50
  """
51
 
52
  def query_llm(prompt, history, persona):
@@ -57,10 +45,9 @@ def query_llm(prompt, history, persona):
57
  input_text += f"User: {u}\nAssistant: {a}\n"
58
  input_text += f"User: {prompt}\nAssistant:"
59
 
60
- out = chat_model(input_text, max_new_tokens=300, do_sample=True, temperature=0.7)
61
  return out[0]["generated_text"].strip()
62
 
63
-
64
  def multimodal_chat(user_msg, history, persona):
65
  history = history or []
66
  assistant_content = query_llm(user_msg, history, persona)
@@ -74,7 +61,7 @@ def multimodal_chat(user_msg, history, persona):
74
  img = sd_model(parsed["prompt"]).images[0]
75
  history.append([user_msg, "πŸ–ΌοΈ Generated image below."])
76
  else:
77
- history.append([user_msg, "⚠️ Image generation not available on this hardware."])
78
 
79
  elif parsed.get("type") == "chart":
80
  fig = go.Figure()
@@ -97,11 +84,12 @@ def multimodal_chat(user_msg, history, persona):
97
 
98
  return history, img, fig
99
 
100
-
 
 
101
  with gr.Blocks(css="style.css") as demo:
102
- gr.Markdown("🧠 **ZEN Research Lab (Adaptive Edition)**", elem_id="zen-header")
103
 
104
- # Capabilities banner
105
  cap_text = "βœ… Text βœ… Charts βœ… Simulation"
106
  if sd_model is not None:
107
  cap_text += " βœ… Images"
@@ -132,18 +120,17 @@ with gr.Blocks(css="style.css") as demo:
132
  user_msg.submit(respond, inputs=[user_msg, chatbot, persona],
133
  outputs=[chatbot, img_out, chart_out])
134
 
135
- # Example starter buttons
136
  with gr.Accordion("✨ Try these examples"):
137
- with gr.Row():
138
- gr.Examples(
139
- examples=[
140
- ["Draw a futuristic city skyline at night"],
141
- ["Simulate first contact with an alien civilization"],
142
- ["Make a chart of AI adoption from 2010 to 2030"],
143
- ["Explain quantum entanglement in simple terms"],
144
- ],
145
- inputs=[user_msg]
146
- )
147
 
148
  if __name__ == "__main__":
149
  demo.queue(max_size=50).launch()
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
  # ----------------------------
10
+ # Load lightweight text model
11
  # ----------------------------
12
+ text_model_name = "google/flan-t5-small" # tiny, CPU-friendly
13
+ chat_model = pipeline("text2text-generation", model=text_model_name, device=0 if device=="cuda" else -1)
 
 
 
 
 
 
 
 
 
14
 
15
  # ----------------------------
16
  # Try to load Stable Diffusion (only if GPU)
 
20
  try:
21
  from diffusers import StableDiffusionPipeline
22
  sd_model = StableDiffusionPipeline.from_pretrained(
23
+ "stabilityai/stable-diffusion-2-1-base"
 
24
  ).to(device)
25
  except Exception as e:
26
  print("⚠️ Could not load Stable Diffusion:", e)
27
  sd_model = None
28
 
29
+ # ----------------------------
30
+ # Core assistant logic
31
+ # ----------------------------
32
  SYSTEM_PROMPT = """You are ZEN Research Assistant.
33
  You can respond in ONE of these forms:
34
  - Image β†’ {"type":"image","prompt":"<prompt>"}
35
+ - Chart β†’ {"type":"chart","title":"<chart title>","data":[{"x":[1,2,3], "y":[2,4,6], "label":"Example"}]}
36
  - Simulation β†’ {"type":"simulation","topic":"<title>","steps":["...", "..."]}
37
  - Text β†’ plain conversation.
 
 
 
 
38
  """
39
 
40
  def query_llm(prompt, history, persona):
 
45
  input_text += f"User: {u}\nAssistant: {a}\n"
46
  input_text += f"User: {prompt}\nAssistant:"
47
 
48
+ out = chat_model(input_text, max_new_tokens=256)
49
  return out[0]["generated_text"].strip()
50
 
 
51
  def multimodal_chat(user_msg, history, persona):
52
  history = history or []
53
  assistant_content = query_llm(user_msg, history, persona)
 
61
  img = sd_model(parsed["prompt"]).images[0]
62
  history.append([user_msg, "πŸ–ΌοΈ Generated image below."])
63
  else:
64
+ history.append([user_msg, "⚠️ Image generation requires GPU."])
65
 
66
  elif parsed.get("type") == "chart":
67
  fig = go.Figure()
 
84
 
85
  return history, img, fig
86
 
87
+ # ----------------------------
88
+ # Gradio UI
89
+ # ----------------------------
90
  with gr.Blocks(css="style.css") as demo:
91
+ gr.Markdown("🧠 **ZEN Research Lab (Light Mode)**", elem_id="zen-header")
92
 
 
93
  cap_text = "βœ… Text βœ… Charts βœ… Simulation"
94
  if sd_model is not None:
95
  cap_text += " βœ… Images"
 
120
  user_msg.submit(respond, inputs=[user_msg, chatbot, persona],
121
  outputs=[chatbot, img_out, chart_out])
122
 
123
+ # Examples
124
  with gr.Accordion("✨ Try these examples"):
125
+ gr.Examples(
126
+ examples=[
127
+ ["Draw a futuristic city skyline at night"],
128
+ ["Simulate first contact with an alien civilization"],
129
+ ["Make a chart of AI adoption from 2010 to 2030"],
130
+ ["Explain quantum entanglement in simple terms"],
131
+ ],
132
+ inputs=[user_msg]
133
+ )
 
134
 
135
  if __name__ == "__main__":
136
  demo.queue(max_size=50).launch()