willsh1997 commited on
Commit
84a6755
·
verified ·
1 Parent(s): 7bfaef7
Files changed (1) hide show
  1. app.py +35 -36
app.py CHANGED
@@ -66,47 +66,54 @@ from multiprocessing import freeze_support
66
  import gradio as gr
67
  import numpy as np
68
 
69
- from vllm import LLM
 
70
 
71
- @spaces.GPU
72
- def initialize_model():
73
- """Initialize the model - called after proper multiprocessing setup"""
74
- llama3_model_id = "shuyuej/Llama-3.2-1B-Instruct-GPTQ"
75
- llama3_pipe = LLM(
76
- model=llama3_model_id,
77
- quantization="gptq",
78
- gpu_memory_utilization=0.5,
79
- max_model_len=1024
80
- )
81
- return llama3_pipe
82
 
83
- # Global variable to hold the model
84
- llama3_pipe = None
85
 
86
- default_sys_prompt = """You are a helpful chatbot. You respond very conversationally, and help the end user as best as you can."""
 
 
 
 
 
 
87
 
88
  @spaces.GPU
89
  def llama_QA(message_history, system_prompt: str):
90
  """
91
- stupid func for asking llama a question and then getting an answer
92
  inputs:
93
- - input_question [str]: question for llama to answer
 
94
  outputs:
95
  - response [str]: llama's response
96
  """
97
  global llama3_pipe
98
 
99
- # set max gen to 512
100
- sampling_params = llama3_pipe.get_default_sampling_params()
101
- sampling_params.max_tokens = 512
102
-
 
103
  input_message_history = [{"role": "system", "content": system_prompt}]
104
  input_message_history.extend(message_history)
105
-
106
- outputs = llama3_pipe.chat(input_message_history, sampling_params)[0].outputs[0].text
107
- # message_history.append({"role": "assistant", "content": outputs})
108
-
109
- return outputs
 
 
 
 
 
 
 
 
 
110
 
111
 
112
  @dataclass
@@ -257,13 +264,5 @@ def create_demo():
257
 
258
  return demo
259
 
260
-
261
- if __name__ == "__main__":
262
- freeze_support() # Add this for Windows compatibility
263
-
264
- # Initialize the model after freeze_support
265
- llama3_pipe = initialize_model()
266
-
267
- # Create and launch the demo
268
- demo = create_demo()
269
- demo.launch()
 
# --- third-party imports -----------------------------------------------------
import gradio as gr
import numpy as np
import torch
from transformers import pipeline

# Default system prompt used when the caller does not supply one.
default_sys_prompt = """You are a helpful chatbot. You respond very conversationally, and help the end user as best as you can."""

# GPTQ-quantized Llama 3.2 1B Instruct checkpoint served via the
# transformers text-generation pipeline.
llama3_model_id = "shuyuej/Llama-3.2-1B-Instruct-GPTQ"

# Built once at import time; device_map="auto" lets accelerate place the
# weights, fp16 halves the memory footprint, and max_new_tokens caps
# generation length by default.
llama3_pipe = pipeline(
    "text-generation",
    model=llama3_model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    max_new_tokens=512,
)
84
  @spaces.GPU
85
  def llama_QA(message_history, system_prompt: str):
86
  """
87
+ Function for asking llama a question and then getting an answer
88
  inputs:
89
+ - message_history [list]: conversation history
90
+ - system_prompt [str]: system prompt for the model
91
  outputs:
92
  - response [str]: llama's response
93
  """
94
  global llama3_pipe
95
 
96
+ # Lazy initialization - only load model when first called
97
+ if llama3_pipe is None:
98
+ llama3_pipe = initialize_model()
99
+
100
+ # Prepare the message history
101
  input_message_history = [{"role": "system", "content": system_prompt}]
102
  input_message_history.extend(message_history)
103
+
104
+ # Generate response using pipeline
105
+ outputs = llama3_pipe(
106
+ input_message_history,
107
+ max_new_tokens=512,
108
+ do_sample=True,
109
+ temperature=0.7,
110
+ top_p=0.9
111
+ )
112
+
113
+ # Extract the response text
114
+ response = outputs[0]["generated_text"][-1]["content"]
115
+
116
+ return response
117
 
118
 
119
  @dataclass
 
264
 
265
  return demo
266
 
# Build and serve the Gradio app. Guarded so importing this module (e.g. for
# tests or tooling) does not start the web server; `python app.py` — which is
# how Hugging Face Spaces runs the file — behaves exactly as before.
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()