anaspro committed on
Commit
a4a88c0
·
1 Parent(s): 98c7af3
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -4,22 +4,21 @@ import torch
4
  from PIL import Image
5
  from pathlib import Path
6
  from threading import Thread
7
- from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
8
  import spaces
9
  import time
10
  import os
11
 
 
12
  # model config - Single model: Shako v4
13
  model_name = "anaspro/Shako-4B-it"
14
 
15
- # Load token from environment if available
16
- hf_token = os.getenv("HF_TOKEN")
17
-
18
- model = Gemma3ForConditionalGeneration.from_pretrained(
19
- model_name,
20
  device_map="auto",
21
- torch_dtype=torch.bfloat16,
22
- token=hf_token
23
  ).eval()
24
  processor = AutoProcessor.from_pretrained(model_name, token=hf_token)
25
  # I will add timestamp later
@@ -83,7 +82,7 @@ def format_conversation_history(chat_history):
83
  messages.append({"role": "user", "content": current_user_content})
84
  return messages
85
 
86
- @spaces.GPU(duration=120)
87
  def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
88
  if isinstance(input_data, dict) and "text" in input_data:
89
  text = input_data["text"]
 
4
  from PIL import Image
5
  from pathlib import Path
6
  from threading import Thread
7
+ from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
8
  import spaces
9
  import time
10
  import os
11
 
12
+
13
  # model config - Single model: Shako v4
14
  model_name = "anaspro/Shako-4B-it"
15
 
16
+ model_id = "anaspro/Shako-4B-it"
17
+ processor = AutoProcessor.from_pretrained(model_id)
18
+ model = AutoModelForImageTextToText.from_pretrained(
19
+ model_id,
 
20
  device_map="auto",
21
+ torch_dtype=torch.bfloat16
 
22
  ).eval()
23
  processor = AutoProcessor.from_pretrained(model_name, token=hf_token)
24
  # I will add timestamp later
 
82
  messages.append({"role": "user", "content": current_user_content})
83
  return messages
84
 
85
+ @spaces.GPU() # duration=120
86
  def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
87
  if isinstance(input_data, dict) and "text" in input_data:
88
  text = input_data["text"]