airsltd committed on
Commit
383b979
·
verified ·
1 Parent(s): 5392a8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -117
app.py CHANGED
@@ -1,19 +1,25 @@
1
  #!/usr/bin/env python3
2
  """
3
- Combined application that automatically downloads the model if needed and starts the FastAPI server.
 
4
  """
5
 
6
  import os
7
  import sys
8
  from pathlib import Path
 
 
 
 
 
 
 
 
9
 
10
- # Check if model exists, if not download it
11
  def check_and_download_model():
12
  """Check if model exists in cache, if not download it"""
13
- from transformers import AutoTokenizer, AutoModelForCausalLM
14
- from huggingface_hub import login
15
 
16
- # 下一步测试 mlx-community/functiongemma-270m-it-4bit
17
  # Use TinyLlama - a fully public model
18
  # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
19
  model_name = "unsloth/functiongemma-270m-it"
@@ -65,121 +71,128 @@ def check_and_download_model():
65
  print("3. Network connection issues")
66
  sys.exit(1)
67
 
68
- def main():
69
- """Main function to start the application"""
70
- print("=" * 60)
71
- print("FunctionGemma FastAPI Server")
72
- print("=" * 60)
73
-
74
- # Check and download model if needed
75
- model_name, cache_dir = check_and_download_model()
76
-
77
- # Now import and start the FastAPI app
78
- print("\nStarting FastAPI server...")
79
-
80
- from fastapi import FastAPI
81
- from transformers import pipeline
82
 
83
- app = FastAPI(title="FunctionGemma API", version="1.0.0")
 
84
 
85
- # Initialize pipeline
86
  print(f"Initializing pipeline with {model_name}...")
87
  pipe = pipeline("text-generation", model=model_name)
88
  print("✓ Pipeline initialized successfully!")
89
-
90
- @app.get("/")
91
- def greet_json():
92
- return {
93
- "message": "FunctionGemma API is running!",
94
- "model": model_name,
95
- "status": "ready"
96
- }
97
-
98
- @app.get("/health")
99
- def health_check():
100
- return {"status": "healthy", "model": model_name}
101
-
102
- @app.get("/generate")
103
- def generate_text(prompt: str = "Who are you?"):
104
- """Generate text using the model"""
105
- messages = [{"role": "user", "content": prompt}]
106
- result = pipe(messages, max_new_tokens=100)
107
- return {"response": result[0]["generated_text"]}
108
-
109
- @app.post("/chat")
110
- def chat_completion(messages: list):
111
- """Chat completion endpoint"""
112
- result = pipe(messages, max_new_tokens=200)
113
- return {"response": result[0]["generated_text"]}
114
-
115
- @app.post("/v1/chat/completions")
116
- def openai_chat_completions(request: dict):
117
- print('\n\n request')
118
- print(request)
119
- """
120
- OpenAI-compatible chat completions endpoint
121
- Expected request format:
122
- {
123
- "model": "google/gemma-2b-it",
124
- "messages": [
125
- {"role": "user", "content": "Hello"}
126
- ],
127
- "max_tokens": 100,
128
- "temperature": 0.7
129
- }
130
- """
131
- import time
132
-
133
- messages = request.get("messages", [])
134
- model = request.get("model", model_name)
135
- max_tokens = request.get("max_tokens", 100)
136
- temperature = request.get("temperature", 0.7)
137
- print('\n\n messages')
138
- print(messages)
139
- print('\n\n model')
140
- print(model)
141
- print('\n\n max_tokens')
142
- print(max_tokens)
143
- print('\n\n temperature')
144
- print(temperature)
145
-
146
- # Generate response
147
- result = pipe(
148
- messages,
149
- max_new_tokens=max_tokens,
150
- # temperature=temperature
151
- )
152
- print('asdfasdfasdfasdf')
153
-
154
- completion_id = f"chatcmpl-{int(time.time())}"
155
- created = int(time.time())
156
-
157
- return {
158
- "id": completion_id,
159
- "object": "chat.completion",
160
- "created": created,
161
- "model": model,
162
- "choices": [
163
- {
164
- "index": 0,
165
- "message": {
166
- "role": "assistant",
167
- "content": result[0]["generated_text"]
168
- },
169
- "finish_reason": "stop"
170
- }
171
- ],
172
- "usage": {
173
- "prompt_tokens": 0, # Would need tokenizer to calculate
174
- "completion_tokens": 0,
175
- "total_tokens": 0
 
 
 
 
 
176
  }
 
 
 
 
 
177
  }
178
-
179
- # Run the server
180
- import uvicorn
 
 
 
 
 
 
 
 
181
  print("\n" + "=" * 60)
182
- print("Server starting at http://localhost:8000")
183
  print("Available endpoints:")
184
  print(" GET / - Welcome message")
185
  print(" GET /health - Health check")
@@ -187,8 +200,3 @@ def main():
187
  print(" POST /chat - Chat completion")
188
  print(" POST /v1/chat/completions - OpenAI-compatible endpoint")
189
  print("=" * 60 + "\n")
190
-
191
- uvicorn.run(app, host="0.0.0.0", port=7860)
192
-
193
- if __name__ == "__main__":
194
- main()
 
1
  #!/usr/bin/env python3
2
  """
3
+ FastAPI application for FunctionGemma with HuggingFace login support.
4
+ This file is designed to be run with: uvicorn app:app --host 0.0.0.0 --port 7860
5
  """
6
 
7
  import os
8
  import sys
9
  from pathlib import Path
10
+ from fastapi import FastAPI
11
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
12
+ from huggingface_hub import login
13
+
14
+ # Global variables
15
+ model_name = None
16
+ pipe = None
17
+ app = FastAPI(title="FunctionGemma API", version="1.0.0")
18
 
 
19
  def check_and_download_model():
20
  """Check if model exists in cache, if not download it"""
21
+ global model_name
 
22
 
 
23
  # Use TinyLlama - a fully public model
24
  # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
25
  model_name = "unsloth/functiongemma-270m-it"
 
71
  print("3. Network connection issues")
72
  sys.exit(1)
73
 
74
def initialize_pipeline():
    """Build the global text-generation pipeline, resolving the model first if needed.

    Side effects: sets the module-level globals ``model_name`` (if not already
    resolved via check_and_download_model) and ``pipe``. Prints progress to stdout.
    """
    global pipe, model_name

    # Resolve the model identifier lazily; check_and_download_model also
    # returns a cache dir, which this function does not need.
    if model_name is None:
        resolved_name, _cache = check_and_download_model()
        model_name = resolved_name

    print(f"Initializing pipeline with {model_name}...")
    pipe = pipeline("text-generation", model=model_name)
    print("✓ Pipeline initialized successfully!")
84
+
85
# API Endpoints
@app.get("/")
def greet_json():
    """Welcome endpoint: reports that the API is up and which model is configured."""
    payload = {
        "message": "FunctionGemma API is running!",
        "model": model_name,
        "status": "ready",
    }
    return payload
93
+
94
@app.get("/health")
def health_check():
    """Lightweight liveness probe reporting the configured model."""
    return dict(status="healthy", model=model_name)
97
+
98
@app.get("/generate")
def generate_text(prompt: str = "Who are you?"):
    """Run the model on a single user *prompt* and return its generated text."""
    # Lazily build the pipeline on first use (startup normally does this already).
    if pipe is None:
        initialize_pipeline()

    chat = [{"role": "user", "content": prompt}]
    outputs = pipe(chat, max_new_tokens=100)
    first = outputs[0]
    return {"response": first["generated_text"]}
107
+
108
@app.post("/chat")
def chat_completion(messages: list):
    """Chat completion over a full message list (role/content dicts)."""
    # Lazily build the pipeline on first use (startup normally does this already).
    if pipe is None:
        initialize_pipeline()

    generated = pipe(messages, max_new_tokens=200)
    return {"response": generated[0]["generated_text"]}
116
+
117
@app.post("/v1/chat/completions")
def openai_chat_completions(request: dict):
    """
    OpenAI-compatible chat completions endpoint.

    Expected request format:
    {
        "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "messages": [
            {"role": "user", "content": "Hello"}
        ],
        "max_tokens": 100,
        "temperature": 0.7
    }

    Returns an OpenAI-style ``chat.completion`` payload. Token usage is
    reported as zeros: computing real counts would require running the
    tokenizer over prompt and completion.
    """
    import time

    # Lazily build the pipeline on first use (startup normally does this already).
    if pipe is None:
        initialize_pipeline()

    messages = request.get("messages", [])
    model = request.get("model", model_name)
    max_tokens = request.get("max_tokens", 100)
    # Accepted for API compatibility but not forwarded to the pipeline:
    # greedy decoding is used, and honoring temperature would also require
    # enabling sampling (do_sample=True).
    temperature = request.get("temperature", 0.7)  # noqa: F841

    # Generate the assistant reply.
    result = pipe(messages, max_new_tokens=max_tokens)

    # Use a single timestamp so the completion id and "created" field agree.
    created = int(time.time())
    completion_id = f"chatcmpl-{created}"

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": result[0]["generated_text"],
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 0,  # Would need tokenizer to calculate
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }
184
+
185
+ # Initialize model on startup
186
+ @app.on_event("startup")
187
+ async def startup_event():
188
+ """Initialize the model when the app starts"""
189
+ print("=" * 60)
190
+ print("FunctionGemma FastAPI Server")
191
+ print("=" * 60)
192
+ print("Initializing model...")
193
+ initialize_pipeline()
194
  print("\n" + "=" * 60)
195
+ print("Server ready at http://0.0.0.0:7860")
196
  print("Available endpoints:")
197
  print(" GET / - Welcome message")
198
  print(" GET /health - Health check")
 
200
  print(" POST /chat - Chat completion")
201
  print(" POST /v1/chat/completions - OpenAI-compatible endpoint")
202
  print("=" * 60 + "\n")