Nihal2000 committed on
Commit 18a5b6f · verified · 1 Parent(s): 4888362

Update app.py

Files changed (1)
  1. app.py +101 -45
app.py CHANGED
@@ -1,87 +1,143 @@
 import os
 import gradio as gr
+
 from src.model_manager import ModelManager
 from src.inference_engine import InferenceEngine
 
 ASSETS_DIR = "assets"
+MODELS_DIR = os.path.join(ASSETS_DIR, "models")
 
-# Initialize once
-manager = ModelManager(os.path.join(ASSETS_DIR, "models"))
+# Ensure directories exist (prevents path issues)
+os.makedirs(ASSETS_DIR, exist_ok=True)
+os.makedirs(MODELS_DIR, exist_ok=True)
 
-def list_models():
-    models = manager.get_available_models()
-    return models
+# Initialize global model manager
+manager = ModelManager(MODELS_DIR)
 
-# Cache loaded engines by model name
-_engines = {}
+# Cache of InferenceEngine per model filename
+_ENGINE_CACHE = {}
+
+def list_models():
+    """Return available model filenames from assets/models"""
+    return manager.get_available_models()
 
-def load_engine(model_name):
-    if model_name in _engines:
-        return _engines[model_name]
+def load_engine(model_name: str) -> InferenceEngine:
+    """Return a cached InferenceEngine for selected model"""
+    if model_name in _ENGINE_CACHE:
+        return _ENGINE_CACHE[model_name]
     model, tokenizer, config = manager.load_model(model_name)
     engine = InferenceEngine(model, tokenizer, config)
-    _engines[model_name] = engine
+    _ENGINE_CACHE[model_name] = engine
     return engine
 
 def chat_fn(message, history, model_name, max_tokens, temperature, top_p, top_k):
+    """
+    Gradio Chatbot callback.
+    - history: list of dicts [{role: "user"/"assistant", content: "..."}, ...]
+    - message: latest user message string
+    """
     if not model_name:
-        return history + [[message, "No model selected. Please choose a model."]]
+        # Append assistant message indicating the issue
+        history = history + [{"role": "assistant", "content": "No model selected. Please choose a model from the right panel."}]
+        return history
+
     try:
         engine = load_engine(model_name)
     except Exception as e:
-        return history + [[message, f"Error loading model: {e}"]]
-    reply = engine.generate_response(
-        message,
-        max_tokens=max_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k
-    )
-    return history + [[message, reply]]
+        history = history + [{"role": "assistant", "content": f"Error loading model: {e}"}]
+        return history
+
+    try:
+        reply = engine.generate_response(
+            message,
+            max_tokens=int(max_tokens),
+            temperature=float(temperature),
+            top_p=float(top_p),
+            top_k=int(top_k),
+        )
+    except Exception as e:
+        reply = f"An error occurred during generation: {e}"
+
+    # Append the user and assistant messages in messages format
+    history = history + [
+        {"role": "user", "content": message},
+        {"role": "assistant", "content": reply},
+    ]
+    return history
 
 def clear_chat():
+    """Reset chat history"""
     return []
 
 with gr.Blocks(title="Automotive SLM Chatbot") as demo:
     gr.Markdown("# 🚗 Automotive SLM Chatbot (Gradio)")
+    gr.Markdown("Small Language Model for automotive assistance. Select a model and start chatting.")
+
     with gr.Row():
         with gr.Column(scale=3):
-            chatbot = gr.Chatbot(height=450, label="Chat")
-            msg = gr.Textbox(placeholder="Ask about automotive topics...", label="Your message")
+            chatbot = gr.Chatbot(
+                label="Chat",
+                height=500,
+                type="messages"  # use OpenAI-style messages
+            )
+            msg = gr.Textbox(
+                placeholder="Ask about automotive topics (e.g., tire pressure, check engine light, EV charging)...",
+                label="Your message"
+            )
             with gr.Row():
                 send_btn = gr.Button("Send", variant="primary")
                 clear_btn = gr.Button("Clear")
+
         with gr.Column(scale=2):
             gr.Markdown("### Model settings")
             available = list_models()
             if not available:
-                # Show a friendly message and stop early
-                import gradio as gr
-                with gr.Row():
-                    gr.Markdown("No models found in assets/models. Please add .pt/.pth/.onnx files and refresh.")
+                gr.Markdown("No models found in assets/models. Please add .pt/.pth/.onnx files and refresh the Space.")
+                # Disabled controls to avoid wiring errors
+                model_dropdown = gr.Dropdown(choices=[], value=None, label="Model", interactive=False)
+                max_tokens = gr.Slider(10, 256, value=64, step=1, label="Max tokens", interactive=False)
+                temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature", interactive=False)
+                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p", interactive=False)
+                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k", interactive=False)
             else:
+                # Optional: show size labels
+                def size_mb(path):
+                    try:
+                        return os.path.getsize(path) / (1024 * 1024)
+                    except Exception:
+                        return 0.0
+                labels = []
+                for name in available:
+                    mb = size_mb(os.path.join(MODELS_DIR, name))
+                    labels.append(f"{name} ({mb:.1f} MB)")
+                # Map labels to values so dropdown shows label but value is filename
+                choices = list(zip(labels, available))
+
                 model_dropdown = gr.Dropdown(
-                    choices=available,
+                    choices=choices,
                     value=available[0],
                     label="Model"
                 )
-                max_tokens = gr.Slider(10, 256, value=64, step=1, label="Max tokens")
-                temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
-                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
-                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
-                gr.Markdown("Tip: lower temperature for more deterministic answers.")
-    # Events
-    send_evt = send_btn.click(
-        fn=chat_fn,
-        inputs=[msg, chatbot, model_dropdown, max_tokens, temperature, top_p, top_k],
-        outputs=[chatbot]
-    )
-    msg.submit(
-        fn=chat_fn,
-        inputs=[msg, chatbot, model_dropdown, max_tokens, temperature, top_p, top_k],
-        outputs=[chatbot]
-    )
-    clear_btn.click(clear_chat, inputs=None, outputs=[chatbot])
+                max_tokens = gr.Slider(10, 256, value=64, step=1, label="Max tokens")
+                temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
+                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
+                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
+
+    gr.Markdown("Tip: Lower temperature and higher top-k/top-p can make answers more focused.")
+
+    # Wire events only if models are available
+    if available:
+        send_btn.click(
+            fn=chat_fn,
+            inputs=[msg, chatbot, model_dropdown, max_tokens, temperature, top_p, top_k],
+            outputs=[chatbot]
+        )
+        msg.submit(
+            fn=chat_fn,
+            inputs=[msg, chatbot, model_dropdown, max_tokens, temperature, top_p, top_k],
+            outputs=[chatbot]
+        )
+        clear_btn.click(clear_chat, None, chatbot)
 
 if __name__ == "__main__":
     demo.launch()
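
Reviewer note: app.py relies on a small surface from src/ that this commit does not touch: ModelManager(models_dir) with get_available_models() and load_model(name) returning (model, tokenizer, config), plus InferenceEngine(model, tokenizer, config) with generate_response(message, max_tokens=..., temperature=..., top_p=..., top_k=...). The stubs below are a minimal sketch of that assumed interface, for local smoke-testing of the UI only; the class bodies and return values here are assumptions, not the actual src/ implementations.

import os

class ModelManager:
    """Hypothetical stand-in for src.model_manager.ModelManager."""
    def __init__(self, models_dir):
        self.models_dir = models_dir

    def get_available_models(self):
        # app.py expects a list of model filenames found in assets/models
        return [f for f in os.listdir(self.models_dir)
                if f.endswith((".pt", ".pth", ".onnx"))]

    def load_model(self, model_name):
        # app.py unpacks exactly three values: model, tokenizer, config
        return object(), object(), {"name": model_name}

class InferenceEngine:
    """Hypothetical stand-in for src.inference_engine.InferenceEngine."""
    def __init__(self, model, tokenizer, config):
        self.config = config

    def generate_response(self, message, max_tokens, temperature, top_p, top_k):
        # Echo stub; a real engine would sample from the model with these settings
        return f"[{self.config['name']}] echo: {message}"

Dropping these two classes into src/model_manager.py and src/inference_engine.py, with any dummy .pt file in assets/models, is enough to launch the Blocks UI above and exercise the new messages-format chat path end to end.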