Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,7 +9,7 @@ from tokenizers import Tokenizer
|
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
import json
|
| 11 |
from abc import ABC, abstractmethod
|
| 12 |
-
from fastapi import FastAPI, HTTPException
|
| 13 |
from fastapi.responses import StreamingResponse
|
| 14 |
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
from pydantic import BaseModel
|
|
@@ -170,11 +170,12 @@ class SAM1Model(keras.Model):
|
|
| 170 |
|
| 171 |
self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
|
| 172 |
|
|
|
|
| 173 |
ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
|
| 174 |
block_args = {
|
| 175 |
'd_model': self.cfg['d_model'],
|
| 176 |
'n_heads': self.cfg['n_heads'],
|
| 177 |
-
'ff_dim': ff_num,
|
| 178 |
'dropout': self.cfg['dropout'],
|
| 179 |
'max_len': self.cfg['max_len'],
|
| 180 |
'rope_theta': self.cfg['rope_theta']
|
|
@@ -201,7 +202,7 @@ class SAM1Model(keras.Model):
|
|
| 201 |
|
| 202 |
|
| 203 |
# ==============================================================================
|
| 204 |
-
# Helper
|
| 205 |
# ==============================================================================
|
| 206 |
|
| 207 |
def count_parameters(model):
|
|
@@ -399,7 +400,7 @@ async def generate_stream(prompt: str, backend, temperature: float) -> AsyncGene
|
|
| 399 |
|
| 400 |
def chat_fn(message, history, model_choice="SAM-X-1-Large", temperature=0.7):
|
| 401 |
backend = available_models[model_choice]
|
| 402 |
-
prompt = f"User: {message}\nSam:"
|
| 403 |
response = ""
|
| 404 |
for chunk in generate_stream(prompt, backend, temperature):
|
| 405 |
response += chunk
|
|
@@ -442,7 +443,7 @@ async def chat_completions(request: ChatCompletionRequest):
|
|
| 442 |
for msg in request.messages:
|
| 443 |
prefix = "User" if msg.role.lower() == "user" else "Sam"
|
| 444 |
prompt_parts.append(f"{prefix}: {msg.content}")
|
| 445 |
-
prompt_parts.append("Sam:")
|
| 446 |
prompt = "\n".join(prompt_parts)
|
| 447 |
|
| 448 |
async def event_stream():
|
|
@@ -466,11 +467,12 @@ async def chat_completions(request: ChatCompletionRequest):
|
|
| 466 |
else:
|
| 467 |
full = ""
|
| 468 |
async for token in event_stream():
|
| 469 |
-
if
|
| 470 |
data = json.loads(token.replace("data: ", "").strip())
|
| 471 |
full += data["choices"][0]["delta"]["content"]
|
| 472 |
return {"choices": [{"message": {"content": full}}]}
|
| 473 |
|
|
|
|
| 474 |
@app.get("/v1/models")
|
| 475 |
async def list_models():
|
| 476 |
return {
|
|
@@ -506,5 +508,5 @@ with gr.Blocks(title="SAM-X-1 Chat", theme=gr.themes.Soft()) as demo:
|
|
| 506 |
]
|
| 507 |
)
|
| 508 |
|
| 509 |
-
# Mount Gradio app on root
|
| 510 |
app = gr.mount_gradio_app(app, demo, path="/")
|
|
|
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
import json
|
| 11 |
from abc import ABC, abstractmethod
|
| 12 |
+
from fastapi import FastAPI, HTTPException
|
| 13 |
from fastapi.responses import StreamingResponse
|
| 14 |
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
from pydantic import BaseModel
|
|
|
|
| 170 |
|
| 171 |
self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
|
| 172 |
|
| 173 |
+
# ✅ FIXED: Was using 'ff_num' — now correctly uses 'ff_dim'
|
| 174 |
ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
|
| 175 |
block_args = {
|
| 176 |
'd_model': self.cfg['d_model'],
|
| 177 |
'n_heads': self.cfg['n_heads'],
|
| 178 |
+
'ff_dim': ff_dim, # ✅ Correct variable name
|
| 179 |
'dropout': self.cfg['dropout'],
|
| 180 |
'max_len': self.cfg['max_len'],
|
| 181 |
'rope_theta': self.cfg['rope_theta']
|
|
|
|
| 202 |
|
| 203 |
|
| 204 |
# ==============================================================================
|
| 205 |
+
# Helper Functions
|
| 206 |
# ==============================================================================
|
| 207 |
|
| 208 |
def count_parameters(model):
|
|
|
|
| 400 |
|
| 401 |
def chat_fn(message, history, model_choice="SAM-X-1-Large", temperature=0.7):
|
| 402 |
backend = available_models[model_choice]
|
| 403 |
+
prompt = f"User: {message}\nSam: <think>"
|
| 404 |
response = ""
|
| 405 |
for chunk in generate_stream(prompt, backend, temperature):
|
| 406 |
response += chunk
|
|
|
|
| 443 |
for msg in request.messages:
|
| 444 |
prefix = "User" if msg.role.lower() == "user" else "Sam"
|
| 445 |
prompt_parts.append(f"{prefix}: {msg.content}")
|
| 446 |
+
prompt_parts.append("Sam: <think>")
|
| 447 |
prompt = "\n".join(prompt_parts)
|
| 448 |
|
| 449 |
async def event_stream():
|
|
|
|
| 467 |
else:
|
| 468 |
full = ""
|
| 469 |
async for token in event_stream():
|
| 470 |
+
if "[DONE]" not in token:
|
| 471 |
data = json.loads(token.replace("data: ", "").strip())
|
| 472 |
full += data["choices"][0]["delta"]["content"]
|
| 473 |
return {"choices": [{"message": {"content": full}}]}
|
| 474 |
|
| 475 |
+
|
| 476 |
@app.get("/v1/models")
|
| 477 |
async def list_models():
|
| 478 |
return {
|
|
|
|
| 508 |
]
|
| 509 |
)
|
| 510 |
|
| 511 |
+
# Mount Gradio app on root path
|
| 512 |
app = gr.mount_gradio_app(app, demo, path="/")
|