FractalAIR committed on
Commit
3b68d21
·
verified ·
1 Parent(s): c34c8d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -22
app.py CHANGED
@@ -8,36 +8,29 @@ MODEL_ID = "FractalAIResearch/Fathom-R1-14B"
8
  @spaces.GPU
9
  def chat_with_model(message, history, max_tokens, temperature):
10
  try:
11
- print("🔥 GPU allocated, loading model...")
12
 
13
- # Load model and tokenizer
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_ID,
17
  torch_dtype=torch.bfloat16,
18
- trust_remote_code=True
 
19
  )
20
 
21
- # EXPLICITLY move model to GPU
22
- model = model.cuda()
23
-
24
- print(f"✅ Model loaded on device: {model.device}")
25
- print(f"🔥 GPU available: {torch.cuda.is_available()}")
26
- print(f"🔥 GPU device count: {torch.cuda.device_count()}")
27
-
28
  if tokenizer.pad_token is None:
29
  tokenizer.pad_token = tokenizer.eos_token
30
 
 
 
31
  # Simple prompt format
32
  prompt = f"User: {message}\nAssistant:"
33
 
34
- # Tokenize and move to GPU
35
  inputs = tokenizer(prompt, return_tensors="pt")
36
- inputs = {k: v.cuda() for k, v in inputs.items()}
37
 
38
- print(f"✅ Inputs moved to: {inputs['input_ids'].device}")
39
-
40
- # Generate
41
  with torch.no_grad():
42
  outputs = model.generate(
43
  **inputs,
@@ -51,15 +44,15 @@ def chat_with_model(message, history, max_tokens, temperature):
51
  # Decode response
52
  response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
53
 
54
- print(f"✅ Generated response: {response[:100]}...")
55
-
56
  # Update history
57
  history.append([message, response])
58
  return history, history, ""
59
 
60
  except Exception as e:
61
- error_msg = f"Error: {str(e)}"
62
- print(error_msg)
 
 
63
  history.append([message, error_msg])
64
  return history, history, ""
65
 
@@ -86,8 +79,8 @@ with gr.Blocks(title="Fathom R1 14B Chatbot") as demo:
86
  gr.Markdown("### Settings")
87
  max_tokens = gr.Slider(
88
  minimum=50,
89
- maximum=2048,
90
- value=512,
91
  step=50,
92
  label="Max Tokens"
93
  )
@@ -103,7 +96,7 @@ with gr.Blocks(title="Fathom R1 14B Chatbot") as demo:
103
  gr.Examples(
104
  examples=[
105
  "Solve: 2x + 5 = 15",
106
- "Explain quantum mechanics simply",
107
  "What is the derivative of x²?",
108
  ],
109
  inputs=msg
 
8
  @spaces.GPU
9
  def chat_with_model(message, history, max_tokens, temperature):
10
  try:
11
+ print("Loading model...")
12
 
13
+ # Load model and tokenizer - let ZeroGPU handle device placement
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_ID,
17
  torch_dtype=torch.bfloat16,
18
+ trust_remote_code=True,
19
+ device_map="auto" # Let transformers handle GPU placement
20
  )
21
 
 
 
 
 
 
 
 
22
  if tokenizer.pad_token is None:
23
  tokenizer.pad_token = tokenizer.eos_token
24
 
25
+ print(f"Model loaded successfully on device: {next(model.parameters()).device}")
26
+
27
  # Simple prompt format
28
  prompt = f"User: {message}\nAssistant:"
29
 
30
+ # Tokenize - let the model handle device placement
31
  inputs = tokenizer(prompt, return_tensors="pt")
 
32
 
33
+ # Generate - the model will automatically handle device placement
 
 
34
  with torch.no_grad():
35
  outputs = model.generate(
36
  **inputs,
 
44
  # Decode response
45
  response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
46
 
 
 
47
  # Update history
48
  history.append([message, response])
49
  return history, history, ""
50
 
51
  except Exception as e:
52
+ error_msg = f"Error: {str(e)}"
53
+ print(f"Full error: {e}")
54
+ import traceback
55
+ traceback.print_exc()
56
  history.append([message, error_msg])
57
  return history, history, ""
58
 
 
79
  gr.Markdown("### Settings")
80
  max_tokens = gr.Slider(
81
  minimum=50,
82
+ maximum=1024,
83
+ value=256,
84
  step=50,
85
  label="Max Tokens"
86
  )
 
96
  gr.Examples(
97
  examples=[
98
  "Solve: 2x + 5 = 15",
99
+ "Explain quantum mechanics simply",
100
  "What is the derivative of x²?",
101
  ],
102
  inputs=msg