airsltd committed on
Commit
05e9938
·
verified ·
1 Parent(s): 8eae1e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -6
app.py CHANGED
@@ -2,6 +2,7 @@
2
  """
3
  FastAPI application for FunctionGemma with HuggingFace login support.
4
  This file is designed to be run with: uvicorn app:app --host 0.0.0.0 --port 7860
 
5
  """
6
 
7
  import os
@@ -14,11 +15,12 @@ from huggingface_hub import login
14
  # Global variables
15
  model_name = None
16
  pipe = None
 
17
  app = FastAPI(title="FunctionGemma API", version="1.0.0")
18
 
19
  def check_and_download_model():
20
  """Check if model exists in cache, if not download it"""
21
- global model_name
22
 
23
  # Use TinyLlama - a fully public model
24
  # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@@ -32,6 +34,7 @@ def check_and_download_model():
32
 
33
  if snapshot_path.exists() and any(snapshot_path.iterdir()):
34
  print(f"✓ Model {model_name} already exists in cache")
 
35
  return model_name, cache_dir
36
 
37
  print(f"✗ Model {model_name} not found in cache")
@@ -74,13 +77,16 @@ def check_and_download_model():
74
 
75
  def initialize_pipeline():
76
  """Initialize the pipeline with the model"""
77
- global pipe, model_name
78
 
79
  if model_name is None:
80
  model_name, _ = check_and_download_model()
81
 
 
 
 
82
  print(f"Initializing pipeline with {model_name}...")
83
- pipe = pipeline("text-generation", model=model_name)
84
  print("✓ Pipeline initialized successfully!")
85
 
86
  # API Endpoints
@@ -103,7 +109,7 @@ def generate_text(prompt: str = "Who are you?"):
103
  initialize_pipeline()
104
 
105
  messages = [{"role": "user", "content": prompt}]
106
- result = pipe(messages, max_new_tokens=100)
107
  return {"response": result[0]["generated_text"]}
108
 
109
  @app.post("/chat")
@@ -136,7 +142,7 @@ def openai_chat_completions(request: dict):
136
 
137
  messages = request.get("messages", [])
138
  model = request.get("model", model_name)
139
- max_tokens = request.get("max_tokens", 1000)
140
  temperature = request.get("temperature", 0.7)
141
 
142
  print('\n\n request')
@@ -188,11 +194,28 @@ def openai_chat_completions(request: dict):
188
  }
189
  ],
190
  "usage": {
191
- "prompt_tokens": 0, # Would need tokenizer to calculate
192
  "completion_tokens": 0,
193
  "total_tokens": 0
194
  }
195
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  print('\n\n return_json')
197
  print(return_json)
198
  print('return over! \n\n')
 
2
  """
3
  FastAPI application for FunctionGemma with HuggingFace login support.
4
  This file is designed to be run with: uvicorn app:app --host 0.0.0.0 --port 7860
5
+ 修复:增加token计算 (Fix: add token counting)
6
  """
7
 
8
  import os
 
15
  # Global variables
16
  model_name = None
17
  pipe = None
18
+ tokenizer = None # Add global tokenizer
19
  app = FastAPI(title="FunctionGemma API", version="1.0.0")
20
 
21
  def check_and_download_model():
22
  """Check if model exists in cache, if not download it"""
23
+ global model_name, tokenizer # Include tokenizer in global
24
 
25
  # Use TinyLlama - a fully public model
26
  # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
34
 
35
  if snapshot_path.exists() and any(snapshot_path.iterdir()):
36
  print(f"✓ Model {model_name} already exists in cache")
37
+ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) # Load tokenizer if model exists
38
  return model_name, cache_dir
39
 
40
  print(f"✗ Model {model_name} not found in cache")
 
77
 
78
  def initialize_pipeline():
79
  """Initialize the pipeline with the model"""
80
+ global pipe, model_name, tokenizer # Include tokenizer in global
81
 
82
  if model_name is None:
83
  model_name, _ = check_and_download_model()
84
 
85
+ if tokenizer is None: # Ensure tokenizer is loaded
86
+ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./my_model_cache")
87
+
88
  print(f"Initializing pipeline with {model_name}...")
89
+ pipe = pipeline("text-generation", model=model_name, tokenizer=tokenizer) # Pass tokenizer to pipeline
90
  print("✓ Pipeline initialized successfully!")
91
 
92
  # API Endpoints
 
109
  initialize_pipeline()
110
 
111
  messages = [{"role": "user", "content": prompt}]
112
+ result = pipe(messages, max_new_tokens=1000)
113
  return {"response": result[0]["generated_text"]}
114
 
115
  @app.post("/chat")
 
142
 
143
  messages = request.get("messages", [])
144
  model = request.get("model", model_name)
145
+ max_tokens = request.get("max_tokens", 100)
146
  temperature = request.get("temperature", 0.7)
147
 
148
  print('\n\n request')
 
194
  }
195
  ],
196
  "usage": {
197
+ "prompt_tokens": 0,
198
  "completion_tokens": 0,
199
  "total_tokens": 0
200
  }
201
  }
202
+
203
+ # Calculate prompt tokens
204
+ if tokenizer:
205
+ prompt_text = ""
206
+ for message in messages:
207
+ prompt_text += message.get("content", "") + " "
208
+ prompt_tokens = len(tokenizer.encode(prompt_text.strip()))
209
+ return_json["usage"]["prompt_tokens"] = prompt_tokens
210
+
211
+ # Calculate completion tokens
212
+ if tokenizer and result["generations"]:
213
+ completion_text = result["generations"][0][0]["text"]
214
+ completion_tokens = len(tokenizer.encode(completion_text))
215
+ return_json["usage"]["completion_tokens"] = completion_tokens
216
+
217
+ return_json["usage"]["total_tokens"] = return_json["usage"]["prompt_tokens"] + return_json["usage"]["completion_tokens"]
218
+
219
  print('\n\n return_json')
220
  print(return_json)
221
  print('return over! \n\n')