Mr-HASSAN committed on
Commit
7278ce3
·
verified ·
1 Parent(s): c98011b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -71
app.py CHANGED
@@ -1,87 +1,58 @@
1
- # app.py - FIXED VERSION
2
  import gradio as gr
3
- import torch
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
5
- import json
6
  import os
7
 
8
- MODEL_ID = "fdtn-ai/Foundation-Sec-8B"
9
-
10
- print("🚀 Loading model...")
11
-
12
- # FIX: Download and patch config first
13
- from huggingface_hub import hf_hub_download
14
-
15
- # Download config
16
- config_path = hf_hub_download(
17
- repo_id=MODEL_ID,
18
- filename="config.json",
19
- local_dir="./cache"
20
- )
21
-
22
- # Read and fix config
23
- with open(config_path, 'r') as f:
24
- config_data = json.load(f)
25
-
26
- # Fix rope_scaling for Llama 3
27
- if 'rope_scaling' in config_data:
28
- rope = config_data['rope_scaling']
29
- if isinstance(rope, dict):
30
- # Convert to standard format
31
- rope_scaling = {
32
- "type": rope.get("rope_type", "linear"),
33
- "factor": rope.get("factor", 1.0)
34
- }
35
- config_data['rope_scaling'] = rope_scaling
36
-
37
- # Save fixed config
38
- os.makedirs("./fixed_config", exist_ok=True)
39
- fixed_config_path = "./fixed_config/config.json"
40
- with open(fixed_config_path, 'w') as f:
41
- json.dump(config_data, f)
42
-
43
- # Load with fixed config
44
- from transformers import AutoConfig
45
- config = AutoConfig.from_pretrained(fixed_config_path)
46
-
47
- # Load tokenizer
48
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
49
 
50
- # Load model
51
  model = AutoModelForCausalLM.from_pretrained(
52
- MODEL_ID,
53
- config=config,
54
  torch_dtype=torch.float16,
55
  device_map="auto",
56
  trust_remote_code=True
57
  )
 
58
 
59
- print("✅ Model loaded!")
60
-
61
- def generate(prompt, max_tokens=200):
62
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
63
 
64
- outputs = model.generate(
65
- **inputs,
66
- max_new_tokens=max_tokens,
67
- temperature=0.7,
68
- do_sample=True
69
- )
70
 
71
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
72
-
73
- # Create interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  gr.Interface(
75
  generate,
76
- [
77
- gr.Textbox(label="Prompt", lines=3),
78
- gr.Slider(50, 500, value=200, label="Max Tokens")
79
- ],
80
  gr.Textbox(label="Response", lines=10),
81
- title="🔒 Foundation-Sec-8B",
82
- examples=[
83
- ["Explain cybersecurity:"],
84
- ["What is a firewall?"],
85
- ["How to create strong passwords?"]
86
- ]
87
  ).launch(server_name="0.0.0.0")
 
1
+ # app.py - LOAD ON DEMAND
2
  import gradio as gr
3
+ import subprocess
4
+ import tempfile
 
5
  import os
6
 
7
def generate(prompt):
    """Answer *prompt* by loading the model on demand in a subprocess.

    A throwaway Python script is written to a temp file and executed with a
    timeout; the child process loads Foundation-Sec-8B, generates up to 200
    new tokens, and prints the decoded text, which we return.

    Args:
        prompt: User prompt string (arbitrary text, including quotes/newlines).

    Returns:
        The model's decoded output on success, or an "Error: ..." /
        "Timeout - ..." message string on failure.
    """
    import sys

    # NOTE: the prompt is passed to the child via argv, NOT interpolated into
    # the script source. Interpolating user text into code is an arbitrary
    # code-injection hole and breaks on any quote or newline in the prompt.
    script = """
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "fdtn-ai/Foundation-Sec-8B",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("fdtn-ai/Foundation-Sec-8B")

prompt = sys.argv[1]
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
"""

    # Write the script to a temp file so it can be run as its own process.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
        f.write(script)
        script_path = f.name

    try:
        result = subprocess.run(
            # sys.executable guarantees the same interpreter/env as this app.
            [sys.executable, script_path, prompt],
            capture_output=True,
            text=True,
            timeout=120,  # NOTE(review): may be tight for loading an 8B model — confirm
        )
        if result.returncode == 0:
            return result.stdout.strip()
        return f"Error: {result.stderr}"
    except subprocess.TimeoutExpired:
        return "Timeout - Model loading took too long"
    finally:
        # Always remove the temp script — including on timeout, which the
        # original code leaked.
        os.unlink(script_path)
51
+
52
# Build the Gradio UI; each request triggers an on-demand model run in generate().
demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(label="Ask about cybersecurity:"),
    outputs=gr.Textbox(label="Response", lines=10),
    title="Foundation-Sec-8B (On-demand Loading)",
)
# Bind to all interfaces so the Space/container exposes the server externally.
demo.launch(server_name="0.0.0.0")