ZENLLC committed on
Commit
4db636e
·
verified ·
1 Parent(s): dac56f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -18
app.py CHANGED
@@ -1,20 +1,35 @@
1
- import gradio as gr, json, plotly.graph_objects as go
2
  from transformers import pipeline
3
- from diffusers import StableDiffusionPipeline
4
- import torch
5
 
6
  # ----------------------------
7
- # Load models once on startup
8
  # ----------------------------
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- # Text model (fast chat)
12
- chat_model = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta", device=0 if device=="cuda" else -1)
 
 
 
 
 
 
 
13
 
14
- # Image model (stable diffusion)
15
- sd_model = StableDiffusionPipeline.from_pretrained(
16
- "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 if device=="cuda" else torch.float32
17
- ).to(device)
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  SYSTEM_PROMPT = """You are ZEN Research Assistant.
@@ -22,16 +37,14 @@ You can respond in ONE of these forms:
22
  - Image โ†’ {"type":"image","prompt":"<prompt>"}
23
  - Chart โ†’ {"type":"chart","title":"<chart title>","data":[{"x":[...], "y":[...], "label":"<series>"}]}
24
  - Simulation โ†’ {"type":"simulation","topic":"<title>","steps":["...", "..."]}
25
- - Text โ†’ plain conversation, explanation, or reasoning.
26
 
27
  Rules:
28
  - Use JSON ONLY for image, chart, or simulation.
29
  - Simulation = imaginative thought experiment, 3โ€“6 steps.
30
- - If not sure, default to conversational text.
31
  """
32
 
33
  def query_llm(prompt, history, persona):
34
- # Construct conversation
35
  input_text = SYSTEM_PROMPT
36
  if persona != "Default":
37
  input_text += f"\nPersona: {persona}\n"
@@ -39,7 +52,7 @@ def query_llm(prompt, history, persona):
39
  input_text += f"User: {u}\nAssistant: {a}\n"
40
  input_text += f"User: {prompt}\nAssistant:"
41
 
42
- out = chat_model(input_text, max_new_tokens=400, do_sample=True, temperature=0.7)
43
  return out[0]["generated_text"].split("Assistant:")[-1].strip()
44
 
45
 
@@ -52,8 +65,11 @@ def multimodal_chat(user_msg, history, persona):
52
  parsed = json.loads(assistant_content)
53
 
54
  if parsed.get("type") == "image":
55
- img = sd_model(parsed["prompt"]).images[0]
56
- history.append([user_msg, "๐Ÿ–ผ๏ธ Generated image below."])
 
 
 
57
 
58
  elif parsed.get("type") == "chart":
59
  fig = go.Figure()
@@ -65,7 +81,7 @@ def multimodal_chat(user_msg, history, persona):
65
  history.append([user_msg, parsed.get("title","Chart below")])
66
 
67
  elif parsed.get("type") == "simulation":
68
- steps = "\n".join([f"โ†’ {s}" for s in parsed["steps"]])
69
  history.append([user_msg, f"๐Ÿ”ฎ Simulation: {parsed.get('topic','Exploration')}\n{steps}"])
70
 
71
  else:
@@ -78,7 +94,7 @@ def multimodal_chat(user_msg, history, persona):
78
 
79
 
80
  with gr.Blocks(css="style.css") as demo:
81
- gr.Markdown("๐Ÿง  **ZEN Research Lab (API-free Edition)** โ€” Explore, simulate, and create", elem_id="zen-header")
82
 
83
  persona = gr.Dropdown(["Default","Analyst","Artist","Futurist","Philosopher"], label="Mode", value="Default")
84
  chatbot = gr.Chatbot(label="Conversation", height=400)
 
1
# Third-party imports: gradio (UI), json (LLM response parsing),
# plotly (chart rendering), torch (device detection).
import gradio as gr, json, plotly.graph_objects as go, torch
from transformers import pipeline

# ----------------------------
# Detect device
# ----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# ----------------------------
# Load text model (with a CPU-friendly fallback)
# ----------------------------
if device == "cuda":
    text_model_name = "HuggingFaceH4/zephyr-7b-beta"
    text_task = "text-generation"            # causal (decoder-only) LM
else:
    text_model_name = "google/flan-t5-base"  # CPU-friendly
    # BUGFIX: flan-t5 is a seq2seq (encoder-decoder) model and cannot run
    # under the "text-generation" pipeline task — it requires
    # "text2text-generation". Both tasks return [{"generated_text": ...}],
    # so the downstream consumer (query_llm) works unchanged.
    text_task = "text2text-generation"

chat_model = pipeline(text_task, model=text_model_name, device=0 if device == "cuda" else -1)

# ----------------------------
# Try to load Stable Diffusion (only if GPU)
# ----------------------------
sd_model = None
if device == "cuda":
    try:
        # Imported lazily so CPU-only deployments don't need diffusers installed.
        from diffusers import StableDiffusionPipeline
        sd_model = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            torch_dtype=torch.float16,
        ).to(device)
    except Exception as e:
        # Best-effort: the app degrades gracefully (multimodal_chat checks
        # `sd_model is not None`) when image generation is unavailable.
        print("⚠️ Could not load Stable Diffusion:", e)
        sd_model = None
33
 
34
 
35
  SYSTEM_PROMPT = """You are ZEN Research Assistant.
 
37
  - Image โ†’ {"type":"image","prompt":"<prompt>"}
38
  - Chart โ†’ {"type":"chart","title":"<chart title>","data":[{"x":[...], "y":[...], "label":"<series>"}]}
39
  - Simulation โ†’ {"type":"simulation","topic":"<title>","steps":["...", "..."]}
40
+ - Text โ†’ plain conversation.
41
 
42
  Rules:
43
  - Use JSON ONLY for image, chart, or simulation.
44
  - Simulation = imaginative thought experiment, 3โ€“6 steps.
 
45
  """
46
 
47
  def query_llm(prompt, history, persona):
 
48
  input_text = SYSTEM_PROMPT
49
  if persona != "Default":
50
  input_text += f"\nPersona: {persona}\n"
 
52
  input_text += f"User: {u}\nAssistant: {a}\n"
53
  input_text += f"User: {prompt}\nAssistant:"
54
 
55
+ out = chat_model(input_text, max_new_tokens=300, do_sample=True, temperature=0.7)
56
  return out[0]["generated_text"].split("Assistant:")[-1].strip()
57
 
58
 
 
65
  parsed = json.loads(assistant_content)
66
 
67
  if parsed.get("type") == "image":
68
+ if sd_model is not None:
69
+ img = sd_model(parsed["prompt"]).images[0]
70
+ history.append([user_msg, "๐Ÿ–ผ๏ธ Generated image below."])
71
+ else:
72
+ history.append([user_msg, "โš ๏ธ Image generation not available on this hardware."])
73
 
74
  elif parsed.get("type") == "chart":
75
  fig = go.Figure()
 
81
  history.append([user_msg, parsed.get("title","Chart below")])
82
 
83
  elif parsed.get("type") == "simulation":
84
+ steps = "\n".join([f"โ†’ {s}" for s in parsed.get("steps",[])])
85
  history.append([user_msg, f"๐Ÿ”ฎ Simulation: {parsed.get('topic','Exploration')}\n{steps}"])
86
 
87
  else:
 
94
 
95
 
96
  with gr.Blocks(css="style.css") as demo:
97
+ gr.Markdown("๐Ÿง  **ZEN Research Lab (Adaptive Edition)** โ€” works everywhere, GPU unlocks extra powers", elem_id="zen-header")
98
 
99
  persona = gr.Dropdown(["Default","Analyst","Artist","Futurist","Philosopher"], label="Mode", value="Default")
100
  chatbot = gr.Chatbot(label="Conversation", height=400)