Bc-AI committed on
Commit
1df8ec7
·
verified ·
1 Parent(s): 4ebabd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -9,7 +9,7 @@ from tokenizers import Tokenizer
9
  from huggingface_hub import hf_hub_download
10
  import json
11
  from abc import ABC, abstractmethod
12
- from fastapi import FastAPI, HTTPException, Request
13
  from fastapi.responses import StreamingResponse
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel
@@ -170,11 +170,12 @@ class SAM1Model(keras.Model):
170
 
171
  self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
172
 
 
173
  ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
174
  block_args = {
175
  'd_model': self.cfg['d_model'],
176
  'n_heads': self.cfg['n_heads'],
177
- 'ff_dim': ff_num,
178
  'dropout': self.cfg['dropout'],
179
  'max_len': self.cfg['max_len'],
180
  'rope_theta': self.cfg['rope_theta']
@@ -201,7 +202,7 @@ class SAM1Model(keras.Model):
201
 
202
 
203
  # ==============================================================================
204
- # Helper: Parameter Counting
205
  # ==============================================================================
206
 
207
  def count_parameters(model):
@@ -399,7 +400,7 @@ async def generate_stream(prompt: str, backend, temperature: float) -> AsyncGene
399
 
400
  def chat_fn(message, history, model_choice="SAM-X-1-Large", temperature=0.7):
401
  backend = available_models[model_choice]
402
- prompt = f"User: {message}\nSam: <think>"
403
  response = ""
404
  for chunk in generate_stream(prompt, backend, temperature):
405
  response += chunk
@@ -442,7 +443,7 @@ async def chat_completions(request: ChatCompletionRequest):
442
  for msg in request.messages:
443
  prefix = "User" if msg.role.lower() == "user" else "Sam"
444
  prompt_parts.append(f"{prefix}: {msg.content}")
445
- prompt_parts.append("Sam: <think>")
446
  prompt = "\n".join(prompt_parts)
447
 
448
  async def event_stream():
@@ -466,11 +467,12 @@ async def chat_completions(request: ChatCompletionRequest):
466
  else:
467
  full = ""
468
  async for token in event_stream():
469
- if b"[DONE]" not in token.encode():
470
  data = json.loads(token.replace("data: ", "").strip())
471
  full += data["choices"][0]["delta"]["content"]
472
  return {"choices": [{"message": {"content": full}}]}
473
 
 
474
  @app.get("/v1/models")
475
  async def list_models():
476
  return {
@@ -506,5 +508,5 @@ with gr.Blocks(title="SAM-X-1 Chat", theme=gr.themes.Soft()) as demo:
506
  ]
507
  )
508
 
509
- # Mount Gradio app on root
510
  app = gr.mount_gradio_app(app, demo, path="/")
 
9
  from huggingface_hub import hf_hub_download
10
  import json
11
  from abc import ABC, abstractmethod
12
+ from fastapi import FastAPI, HTTPException
13
  from fastapi.responses import StreamingResponse
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel
 
170
 
171
  self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
172
 
173
+ # ✅ FIXED: Was using 'ff_num' — now correctly uses 'ff_dim'
174
  ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
175
  block_args = {
176
  'd_model': self.cfg['d_model'],
177
  'n_heads': self.cfg['n_heads'],
178
+ 'ff_dim': ff_dim, # ✅ Correct variable name
179
  'dropout': self.cfg['dropout'],
180
  'max_len': self.cfg['max_len'],
181
  'rope_theta': self.cfg['rope_theta']
 
202
 
203
 
204
  # ==============================================================================
205
+ # Helper Functions
206
  # ==============================================================================
207
 
208
  def count_parameters(model):
 
400
 
401
  def chat_fn(message, history, model_choice="SAM-X-1-Large", temperature=0.7):
402
  backend = available_models[model_choice]
403
+ prompt = f"User: {message}\nSam: <think>"
404
  response = ""
405
  for chunk in generate_stream(prompt, backend, temperature):
406
  response += chunk
 
443
  for msg in request.messages:
444
  prefix = "User" if msg.role.lower() == "user" else "Sam"
445
  prompt_parts.append(f"{prefix}: {msg.content}")
446
+ prompt_parts.append("Sam: <think>")
447
  prompt = "\n".join(prompt_parts)
448
 
449
  async def event_stream():
 
467
  else:
468
  full = ""
469
  async for token in event_stream():
470
+ if "[DONE]" not in token:
471
  data = json.loads(token.replace("data: ", "").strip())
472
  full += data["choices"][0]["delta"]["content"]
473
  return {"choices": [{"message": {"content": full}}]}
474
 
475
+
476
  @app.get("/v1/models")
477
  async def list_models():
478
  return {
 
508
  ]
509
  )
510
 
511
+ # Mount Gradio app on root path
512
  app = gr.mount_gradio_app(app, demo, path="/")