sixfingerdev commited on
Commit
8bd0408
·
verified ·
1 Parent(s): 6db1b6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -933,12 +933,13 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))</pre>
933
  # ==================== MODEL YÜKLEME ====================
934
  print("Modeller yükleniyor... Bu biraz sürebilir (özellikle ilk seferde).")
935
 
 
 
936
  model_stable = AutoModelForCausalLM.from_pretrained(
937
  "sixfingerdev/kayra-1",
938
  trust_remote_code=True,
939
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
940
- device_map="auto"
941
- )
942
  tokenizer_stable = AutoTokenizer.from_pretrained("sixfingerdev/kayra-1")
943
 
944
  model_exp = AutoModelForCausalLM.from_pretrained(
 
933
  # ==================== MODEL YÜKLEME ====================
934
  print("Modeller yükleniyor... Bu biraz sürebilir (özellikle ilk seferde).")
935
 
936
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
937
+
938
  model_stable = AutoModelForCausalLM.from_pretrained(
939
  "sixfingerdev/kayra-1",
940
  trust_remote_code=True,
941
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
942
+ ).to(device)
 
943
  tokenizer_stable = AutoTokenizer.from_pretrained("sixfingerdev/kayra-1")
944
 
945
  model_exp = AutoModelForCausalLM.from_pretrained(