Commit d24a753 · 1 Parent(s): e4ef2eb
Committed by aadya1762

Stream LLM responses

Files changed (2):
  1. gemmademo/_chat.py +14 -5
  2. gemmademo/_model.py +24 -12
gemmademo/_chat.py CHANGED

@@ -18,8 +18,9 @@ class GradioChat:
         self.model_options = model_options
         self.task_options = task_options
         self.current_model_name = "gemma-2b-it"  # Default model
-        self.model = self._load_model(self.current_model_name)
         self.current_task_name = "Question Answering"  # Default task
+
+        self.model = self._load_model(self.current_model_name)
         self.prompt_manager = self._load_task(self.current_task_name)
 
     def _load_model(self, model_name: str):
@@ -44,15 +45,23 @@ class GradioChat:
 
             # Generate response using updated model & prompt manager
             prompt = self.prompt_manager.get_prompt(user_input=message)
-            response = self.model.generate_response(prompt)
-            return response
+            response_stream = self.model.generate_response(prompt)
+            yield from response_stream
 
         chat_interface = gr.ChatInterface(
            chat_fn,
            textbox=gr.Textbox(placeholder="Ask me something...", container=False),
            additional_inputs=[
-                gr.Dropdown(choices=self.model_options, value=self.current_model_name, label="Select Gemma Model"),
-                gr.Dropdown(choices=self.task_options, value=self.current_task_name, label="Select Task"),
+                gr.Dropdown(
+                    choices=self.model_options,
+                    value=self.current_model_name,
+                    label="Select Gemma Model",
+                ),
+                gr.Dropdown(
+                    choices=self.task_options,
+                    value=self.current_task_name,
+                    label="Select Task",
+                ),
             ],
         )
         chat_interface.launch()
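For reference, here is a minimal standalone sketch (not part of this commit) of how gr.ChatInterface consumes a streaming generator. Gradio re-renders each yielded value as the full bot message so far, so a common pattern is to accumulate token deltas into a running string before yielding; the fake_token_stream function below is a stand-in for a model's delta generator.

# Standalone sketch, not part of this commit. Assumes gradio is installed;
# fake_token_stream is a hypothetical stand-in for model.generate_response().
import time

import gradio as gr


def fake_token_stream(prompt: str):
    # Yield small text deltas, the way a streaming LLM backend would.
    for piece in ("Streaming", " responses", " token", " by", " token."):
        time.sleep(0.1)
        yield piece


def chat_fn(message, history):
    partial = ""
    for delta in fake_token_stream(message):
        partial += delta
        yield partial  # cumulative text so far; Gradio replaces the shown message


if __name__ == "__main__":
    gr.ChatInterface(chat_fn).launch()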
gemmademo/_model.py CHANGED

@@ -59,6 +59,7 @@ class LlamaCppGemmaModel:
         """
         self.name = name
         self.model = None  # Instance of Llama from llama.cpp
+        self.messages = []
 
     def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0):
         """
@@ -73,23 +74,25 @@
             raise ValueError(f"Model {self.name} is not available.")
 
         model_path = model_info["model_path"]
-
+
         # If the model file doesn't exist, download it.
         if not os.path.exists(model_path):
             os.makedirs(os.path.dirname(model_path), exist_ok=True)
             repo_id = model_info.get("repo_id")
             filename = model_info.get("filename")
-
+
             if repo_id is None or filename is None:
-                raise ValueError("Repository ID or filename is missing for model download.")
-
+                raise ValueError(
+                    "Repository ID or filename is missing for model download."
+                )
+
             downloaded_path = hf_hub_download(
                 repo_id=repo_id,
                 filename=filename,
                 local_dir=os.path.dirname(model_path),
                 local_dir_use_symlinks=False,
             )
-
+
             if downloaded_path != model_path:
                 os.rename(downloaded_path, model_path)
 
@@ -101,7 +104,9 @@
         )
         return self
 
-    def generate_response(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
+    def generate_response(
+        self, prompt: str, max_tokens: int = 512, temperature: float = 0.7
+    ):
         """
         Generate a response using the llama.cpp model.
 
@@ -110,18 +115,25 @@
             max_tokens (int): Maximum number of tokens to generate.
             temperature (float): Sampling temperature (higher = more creative).
 
-        Returns:
-            str: Generated response text.
+        Yields:
+            str: Generated response text as a stream.
         """
         if self.model is None:
             self.load_model()
 
-        response = self.model(
-            prompt,
+        self.messages.append({"role": "user", "content": prompt})
+
+        response_stream = self.model.create_chat_completion(
+            messages=self.messages,
             max_tokens=max_tokens,
             temperature=temperature,
+            stream=True,
         )
-        return response["choices"][0]["text"].strip()
+
+        for chunk in response_stream:
+            delta = chunk["choices"][0]["delta"]
+            if "content" in delta:
+                yield delta["content"].strip()
 
     def get_model_info(self) -> Dict:
         """
@@ -139,4 +151,4 @@
         Returns:
             str: Model name.
         """
-        return self.name
+        return self.name
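For context, llama-cpp-python's create_chat_completion(..., stream=True) yields OpenAI-style chunks whose choices[0]["delta"] dict typically carries the role first and then incremental "content" strings, which is why the loop above filters on "content". Below is a hypothetical caller-side sketch of the new streaming generator; the constructor signature and the "gemma-2b-it" registry entry are assumptions inferred from the surrounding code, not part of this commit.

# Hypothetical usage sketch; assumes LlamaCppGemmaModel is importable from
# gemmademo._model and that "gemma-2b-it" is a registered model name.
from gemmademo._model import LlamaCppGemmaModel

model = LlamaCppGemmaModel(name="gemma-2b-it")
model.load_model(n_ctx=2048, n_gpu_layers=0)  # downloads the GGUF file if missing

# generate_response is now a generator: print each content delta as it arrives.
for delta in model.generate_response("Explain streaming in one sentence."):
    print(delta, end="", flush=True)
print()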