rasAli02 commited on
Commit
307f1c5
Β·
1 Parent(s): a0c4b2d

git add, commit, push

Browse files
backend/agents.py CHANGED
@@ -14,6 +14,14 @@ from typing import Optional, List, Dict, Any
14
 
15
  EMERGENT_LLM_KEY = os.environ.get("EMERGENT_LLM_KEY", "")
16
 
 
 
 
 
 
 
 
 
17
  # Model choices β€” Claude Sonnet 4.5 is vision-capable and strong for reasoning.
18
  VISION_MODEL = ("anthropic", "claude-sonnet-4-5-20250929")
19
  TEXT_MODEL = ("anthropic", "claude-sonnet-4-5-20250929")
@@ -91,6 +99,66 @@ def _extract_json(raw: str) -> Dict[str, Any]:
91
  return {"_raw": raw}
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  async def _run_agent(
95
  name: str,
96
  system_message: str,
@@ -98,47 +166,26 @@ async def _run_agent(
98
  image_base64: Optional[str] = None,
99
  provider_model: tuple = TEXT_MODEL,
100
  ) -> Dict[str, Any]:
101
- # MOCKED for preview
 
 
 
 
102
  import asyncio
103
- await asyncio.sleep(0.5)
104
-
105
- parsed = {}
106
- if name == "inspector":
107
- parsed = {
108
- "verdict": "warn",
109
- "confidence": 0.85,
110
- "defects": [
111
- {"type": "surface-scratch", "severity": "low", "location": "top-left edge", "description": "Minor scratch visible"}
112
- ],
113
- "observation": "I can see a minor scratch on the surface."
114
- }
115
- elif name == "diagnostician":
116
- parsed = {
117
- "probable_cause": "Improper handling during milling.",
118
- "contributing_factors": ["Machine calibration", "Operator error"],
119
- "affected_process_step": "CNC milling"
120
- }
121
- elif name == "action":
122
- parsed = {
123
- "priority": "P2",
124
- "assignee_role": "quality-engineer",
125
- "steps": ["Inspect machine", "Recalibrate"],
126
- "estimated_minutes": 30,
127
- "parts_or_tools": ["Calibration kit"]
128
- }
129
- elif name == "reporter":
130
- parsed = {
131
- "headline": "Minor Scratch Detected",
132
- "summary": "A minor scratch was detected during the milling process.",
133
- "tags": ["scratch", "milling"]
134
- }
135
- elif name == "social":
136
- parsed = {
137
- "x_post": "Testing our pipeline #AMDHackathon",
138
- "linkedin_post": "We are testing our pipeline today..."
139
- }
140
-
141
- return {"raw": json.dumps(parsed), "parsed": parsed}
142
 
143
 
144
  async def run_pipeline(
 
14
 
15
  EMERGENT_LLM_KEY = os.environ.get("EMERGENT_LLM_KEY", "")
16
 
17
+ # AMD MI300X inference server (fine-tuned compliance model)
18
+ # Jupyter proxy route used since direct port 8000 is firewalled.
19
+ # Override with AMD_INFERENCE_URL env var if direct access is available.
20
+ AMD_INFERENCE_URL = os.environ.get(
21
+ "AMD_INFERENCE_URL",
22
+ "http://129.212.191.163/proxy/8000"
23
+ )
24
+
25
  # Model choices β€” Claude Sonnet 4.5 is vision-capable and strong for reasoning.
26
  VISION_MODEL = ("anthropic", "claude-sonnet-4-5-20250929")
27
  TEXT_MODEL = ("anthropic", "claude-sonnet-4-5-20250929")
 
99
  return {"_raw": raw}
100
 
101
 
102
+ def _build_prompt(system_message: str, user_text: str) -> str:
103
+ return f"<|system|>{system_message}<|user|>{user_text}<|assistant|>"
104
+
105
+
106
+ def _mock_response(name: str) -> Dict[str, Any]:
107
+ """Fallback mock responses for local development (AMD server not running)."""
108
+ mocks = {
109
+ "inspector": {
110
+ "verdict": "warn", "confidence": 0.85,
111
+ "defects": [{"type": "surface-scratch", "severity": "low", "location": "top-left edge", "description": "Minor scratch visible"}],
112
+ "observation": "Minor scratch detected on surface. [LOCAL MOCK β€” AMD server offline]"
113
+ },
114
+ "diagnostician": {
115
+ "probable_cause": "Improper handling during milling. [LOCAL MOCK]",
116
+ "contributing_factors": ["Machine calibration", "Operator error"],
117
+ "affected_process_step": "CNC milling"
118
+ },
119
+ "action": {
120
+ "priority": "P2", "assignee_role": "quality-engineer",
121
+ "steps": ["Inspect machine", "Recalibrate"],
122
+ "estimated_minutes": 30, "parts_or_tools": ["Calibration kit"]
123
+ },
124
+ "reporter": {
125
+ "headline": "Minor Scratch Detected [Mock]",
126
+ "summary": "Local mock response β€” start the AMD inference server to use the fine-tuned compliance model.",
127
+ "tags": ["scratch", "mock", "local"]
128
+ },
129
+ "social": {
130
+ "x_post": "Testing our pipeline #AMDHackathon",
131
+ "linkedin_post": "We are testing our pipeline today..."
132
+ },
133
+ }
134
+ parsed = mocks.get(name, {})
135
+ return {"raw": json.dumps(parsed), "parsed": parsed, "source": "mock"}
136
+
137
+
138
+ async def _call_amd_server(prompt: str) -> Optional[str]:
139
+ """Call the fine-tuned model running on AMD MI300X. Returns None if unreachable."""
140
+ import asyncio
141
+ import urllib.request
142
+ import urllib.error
143
+
144
+ payload = json.dumps({"prompt": prompt, "max_tokens": 512}).encode()
145
+ req = urllib.request.Request(
146
+ f"{AMD_INFERENCE_URL}/v1/complete",
147
+ data=payload,
148
+ headers={"Content-Type": "application/json"},
149
+ method="POST",
150
+ )
151
+ try:
152
+ loop = asyncio.get_event_loop()
153
+ def _do_request():
154
+ with urllib.request.urlopen(req, timeout=10) as resp:
155
+ return json.loads(resp.read())
156
+ result = await loop.run_in_executor(None, _do_request)
157
+ return result.get("text", "")
158
+ except Exception:
159
+ return None # Server offline β€” caller will use mock
160
+
161
+
162
  async def _run_agent(
163
  name: str,
164
  system_message: str,
 
166
  image_base64: Optional[str] = None,
167
  provider_model: tuple = TEXT_MODEL,
168
  ) -> Dict[str, Any]:
169
+ """
170
+ Run an agent. Tries the AMD MI300X fine-tuned model first.
171
+ Falls back to mock responses automatically if the server is not running
172
+ (e.g. local development without the AMD instance active).
173
+ """
174
  import asyncio
175
+ await asyncio.sleep(0.1)
176
+
177
+ prompt = _build_prompt(system_message, user_text)
178
+ raw_text = await _call_amd_server(prompt)
179
+
180
+ if raw_text is None:
181
+ # AMD server not reachable β€” use local mock (safe for dev)
182
+ result = _mock_response(name)
183
+ result["source"] = "mock (AMD server offline)"
184
+ return result
185
+
186
+ # AMD server responded β€” parse its JSON output
187
+ parsed = _extract_json(raw_text)
188
+ return {"raw": raw_text, "parsed": parsed, "source": f"AMD MI300X @ {AMD_INFERENCE_URL}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
 
191
  async def run_pipeline(
backend/amd_hackathon/fine-tune.yaml CHANGED
@@ -1,25 +1,25 @@
1
- base_model: Qwen/Qwen2.5-32B-Instruct
2
  model_type: AutoModelForCausalLM
3
  tokenizer_type: AutoTokenizer
4
 
5
- # Target Domain: Technical Engineering Specs (Track 2 Focus)
6
  load_in_4bit: true
7
  strict: false
8
 
9
  datasets:
10
- - path: data/engineering_specs_synthetic.jsonl
11
  type: chat_template
12
  chat_template: chatml
13
 
14
  dataset_prepared_path: last_run_prepared
15
  val_set_size: 0.05
16
- output_dir: ./qwen2.5-32b-engineering-lora
17
 
18
  # QLoRA Adapter Configuration
19
  adapter: qlora
20
  lora_model_dir:
21
 
22
- sequence_len: 8192 # Increased sequence length for complex specs
23
  sample_packing: true
24
  eval_sample_packing: false
25
  pad_to_sequence_len: true
@@ -30,7 +30,7 @@ lora_dropout: 0.05
30
  lora_target_linear: true
31
  lora_fan_in_fan_out:
32
 
33
- wandb_project: rocm-qwen32b-engineering
34
  wandb_entity:
35
  wandb_watch:
36
  wandb_name: mi300x-run-1
 
1
+ base_model: Qwen/Qwen3.6-35B
2
  model_type: AutoModelForCausalLM
3
  tokenizer_type: AutoTokenizer
4
 
5
+ # Target Domain: Infrastructure Compliance
6
  load_in_4bit: true
7
  strict: false
8
 
9
  datasets:
10
+ - path: data/infrastructure_compliance.jsonl
11
  type: chat_template
12
  chat_template: chatml
13
 
14
  dataset_prepared_path: last_run_prepared
15
  val_set_size: 0.05
16
+ output_dir: ./qwen3.6-35b-compliance-lora
17
 
18
  # QLoRA Adapter Configuration
19
  adapter: qlora
20
  lora_model_dir:
21
 
22
+ sequence_len: 32768 # 32k context window for long technical documents
23
  sample_packing: true
24
  eval_sample_packing: false
25
  pad_to_sequence_len: true
 
30
  lora_target_linear: true
31
  lora_fan_in_fan_out:
32
 
33
+ wandb_project: rocm-qwen35b-compliance
34
  wandb_entity:
35
  wandb_watch:
36
  wandb_name: mi300x-run-1
backend/amd_hackathon/generate_dataset.py CHANGED
@@ -5,35 +5,35 @@ from pathlib import Path
5
 
6
  # We will save the data to the 'data' directory relative to where the script is run
7
  OUTPUT_DIR = Path("data")
8
- OUTPUT_FILE = OUTPUT_DIR / "engineering_specs_synthetic.jsonl"
9
 
10
- # Sample synthetic data templates focused on Technical Engineering Specs
11
- SYSTEM_PROMPT = "You are an expert industrial diagnostic AI. Given an engineering defect report or visual observation, you must reference technical specifications and provide a highly accurate diagnostic verdict, severity, and remediation steps."
 
12
 
13
  SYNTHETIC_SAMPLES = [
14
  {
15
- "user": "I am looking at a steel beam (Grade A36) with visible surface micro-fractures near the weld joint. The load bearing requirement is 50 kN.",
16
- "assistant": "According to structural engineering standard AWS D1.1, micro-fractures near a weld joint on ASTM A36 steel indicate potential hydrogen-induced cracking or rapid cooling stress. Given the 50 kN load requirement, this poses a critical structural risk. \n\n**Verdict**: FAIL\n**Severity**: CRITICAL\n**Remediation**: Immediate halt of load bearing. Recommend ultrasonic testing (UT) to determine crack depth and subsequent re-welding using low-hydrogen electrodes."
17
  },
18
  {
19
- "user": "The printed circuit board (PCB) shows solder bridging between pins 4 and 5 of the main microcontroller (STM32).",
20
- "assistant": "Solder bridging between pins 4 and 5 constitutes a short circuit, violating IPC-A-610 Class 3 standards for high-reliability electronic assemblies. \n\n**Verdict**: FAIL\n**Severity**: HIGH\n**Remediation**: Rework required. Apply flux and use a desoldering braid or hot air rework station to remove the excess solder. Inspect adjacent pins for thermal damage."
21
  },
22
  {
23
- "user": "Visual inspection of the hydraulic turbine blade shows a 2mm pitting corrosion spot on the leading edge. Material is Titanium alloy.",
24
- "assistant": "Pitting corrosion of 2mm on the leading edge of a Titanium alloy hydraulic turbine blade exceeds the operational tolerance outlined in ISO 11366. It disrupts hydrodynamic efficiency and can lead to cavitation erosion over time.\n\n**Verdict**: WARN\n**Severity**: MEDIUM\n**Remediation**: Flag for scheduled maintenance. The pitting should be ground smooth and re-coated with an anti-cavitation protective layer during the next service cycle."
25
  },
26
  {
27
- "user": "Engine gasket seal appears slightly discolored but maintains full contact and no pressure leak detected during the 100psi test.",
28
- "assistant": "Discoloration of the gasket without pressure loss at 100psi is within acceptable limits for standard thermal wear per manufacturer baseline specs. No functional impairment is detected.\n\n**Verdict**: PASS\n**Severity**: LOW\n**Remediation**: No immediate action required. Continue standard monitoring."
29
  }
30
  ]
31
 
32
  def generate_dataset(num_samples: int = 100):
33
  """
34
- Generates a synthetic JSONL dataset formatted in ChatML.
35
- In a real scenario, you could use an LLM API (e.g. GPT-4 or Claude 3.5 Sonnet)
36
- in a loop here to generate thousands of diverse examples.
37
  """
38
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
39
 
@@ -41,8 +41,8 @@ def generate_dataset(num_samples: int = 100):
41
 
42
  with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
43
  for _ in range(num_samples):
44
- # For demonstration, we just randomly sample from our templates
45
- # A real generator would use an LLM to generate variations
46
  sample = random.choice(SYNTHETIC_SAMPLES)
47
 
48
  chatml_format = {
 
5
 
6
  # We will save the data to the 'data' directory relative to where the script is run
7
  OUTPUT_DIR = Path("data")
8
+ OUTPUT_FILE = OUTPUT_DIR / "infrastructure_compliance.jsonl"
9
 
10
+ # Synthetic data templates focused on Infrastructure Standards (Eurocodes, ASTM, ISO 9001)
11
+ # and 'Violation Reasoning'
12
+ SYSTEM_PROMPT = "You are an expert infrastructure compliance auditor. Given a technical description of a construction element, compare it against relevant engineering standards (Eurocodes, ASTM Standards, ISO 9001) to identify any violations. You must output the label 'Compliant' or 'Non-Compliant' followed by a detailed reasoning trace."
13
 
14
  SYNTHETIC_SAMPLES = [
15
  {
16
+ "user": "A bridge pier is constructed using concrete with a compressive strength of 25 MPa. It is exposed to a severe marine environment (XS3 exposure class).",
17
+ "assistant": "**Label**: Non-Compliant\n\n**Reasoning Trace**: According to Eurocode 2 (EN 1992-1-1), structures in exposure class XS3 (tidal, splash, and spray zones in marine environments) require a minimum concrete strength class of C35/45 to ensure adequate durability and resistance to chloride-induced corrosion. The provided compressive strength of 25 MPa (C25/30) is insufficient for this exposure class, violating the structural code requirements."
18
  },
19
  {
20
+ "user": "The quality management system documentation includes a calibration record for the torque wrenches used on the primary structural steel connections. The calibration was performed 6 months ago, and the required interval is 12 months.",
21
+ "assistant": "**Label**: Compliant\n\n**Reasoning Trace**: Under ISO 9001:2015 Clause 7.1.5 (Monitoring and measuring resources), organizations must retain appropriate documented information as evidence of fitness for purpose of monitoring and measurement resources. The torque wrenches are calibrated within the required 12-month interval (performed 6 months ago), satisfying the quality management standard for infrastructure construction."
22
  },
23
  {
24
+ "user": "Structural steel beams are connected using high-strength structural bolts (ASTM A325). The bolts are installed using the snug-tightened method for a slip-critical connection.",
25
+ "assistant": "**Label**: Non-Compliant\n\n**Reasoning Trace**: In accordance with the AISC Specification and ASTM standards for structural joints using high-strength bolts, slip-critical connections require bolts to be fully tensioned (e.g., using turn-of-nut, calibrated wrench, or direct tension indicators). The snug-tightened method is only permitted for bearing-type connections where slip is acceptable. Using snug-tightened bolts for a slip-critical connection violates the installation standard and poses a risk of joint failure under service loads."
26
  },
27
  {
28
+ "user": "Reinforcing steel (rebar) for the foundation mat is specified as ASTM A615 Grade 60. The rebar is delivered with mill test reports confirming a minimum yield strength of 60,000 psi.",
29
+ "assistant": "**Label**: Compliant\n\n**Reasoning Trace**: ASTM A615 Grade 60 requires deformed carbon-steel bars for concrete reinforcement to have a minimum yield strength of 60,000 psi. The mill test reports confirm that the delivered material meets this mechanical property requirement, ensuring it complies with the specified ASTM standard for the foundation mat."
30
  }
31
  ]
32
 
33
  def generate_dataset(num_samples: int = 100):
34
  """
35
+ Generates a synthetic JSONL dataset formatted in ChatML for Axolotl fine-tuning.
36
+ Focuses on 'Violation Reasoning' for Infrastructure Compliance.
 
37
  """
38
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
39
 
 
41
 
42
  with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
43
  for _ in range(num_samples):
44
+ # For demonstration, we randomly sample from our templates.
45
+ # In production, an LLM pipeline could generate varied scenarios.
46
  sample = random.choice(SYNTHETIC_SAMPLES)
47
 
48
  chatml_format = {
backend/amd_hackathon/track3_vision.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
3
+ from PIL import Image
4
+
5
+ def analyze_construction_site(image_path: str, device: str = "cuda") -> str:
6
+ """
7
+ Uses Qwen2-VL (Track 3) to process a construction site image (e.g., from a drone)
8
+ and output a structured technical description. This description acts as the 'Context'
9
+ for the fine-tuned Track 2 Compliance Auditor model.
10
+ """
11
+ # Initialize the model and processor
12
+ # We use a placeholder path for the Qwen2-VL model here.
13
+ model_id = "Qwen/Qwen2-VL-7B-Instruct"
14
+
15
+ print(f"Loading {model_id} on {device}...")
16
+ try:
17
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
18
+ model_id,
19
+ torch_dtype=torch.bfloat16,
20
+ device_map=device
21
+ )
22
+ processor = AutoProcessor.from_pretrained(model_id)
23
+ except Exception as e:
24
+ print(f"Model loading failed (this is expected if weights aren't downloaded): {e}")
25
+ # Return a mocked structured output for demonstration purposes in the hackathon
26
+ return _mocked_vision_output()
27
+
28
+ # Load the image
29
+ try:
30
+ image = Image.open(image_path).convert("RGB")
31
+ except Exception as e:
32
+ raise ValueError(f"Could not load image at {image_path}: {e}")
33
+
34
+ # Prepare the prompt tailored for technical extraction
35
+ prompt = (
36
+ "You are an expert construction site inspector. Describe the structural elements, "
37
+ "materials, and construction practices visible in this image. Focus on technical "
38
+ "details like concrete pouring, rebar placement, structural steel connections, "
39
+ "and any visible environmental exposure factors. Be highly descriptive and objective."
40
+ )
41
+
42
+ # Qwen2-VL format
43
+ messages = [
44
+ {
45
+ "role": "user",
46
+ "content": [
47
+ {"type": "image", "image": image},
48
+ {"type": "text", "text": prompt},
49
+ ],
50
+ }
51
+ ]
52
+
53
+ # Preprocess inputs
54
+ text = processor.apply_chat_template(
55
+ messages, tokenize=False, add_generation_prompt=True
56
+ )
57
+ image_inputs, video_inputs = processor.image_processor(image), None
58
+ inputs = processor(
59
+ text=[text],
60
+ images=image_inputs,
61
+ videos=video_inputs,
62
+ padding=True,
63
+ return_tensors="pt",
64
+ )
65
+ inputs = inputs.to(device)
66
+
67
+ # Generate output
68
+ print("Analyzing image...")
69
+ with torch.no_grad():
70
+ generated_ids = model.generate(**inputs, max_new_tokens=256)
71
+
72
+ generated_ids_trimmed = [
73
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
74
+ ]
75
+
76
+ output_text = processor.batch_decode(
77
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
78
+ )[0]
79
+
80
+ return _format_for_track2(output_text)
81
+
82
+ def _mocked_vision_output() -> str:
83
+ """Provides a mocked output when running without the heavy VLM weights."""
84
+ mocked_description = (
85
+ "A bridge pier is constructed using concrete. Reinforcement bars are visible with approximately "
86
+ "50mm of concrete cover. The pier is located directly in a tidal splash zone (marine environment). "
87
+ "Concrete surface appears to have minor honeycombing at the base."
88
+ )
89
+ return _format_for_track2(mocked_description)
90
+
91
+ def _format_for_track2(vision_text: str) -> str:
92
+ """
93
+ Structures the vision output so it can be seamlessly passed
94
+ as input 'Context' to the fine-tuned 35B model.
95
+ """
96
+ structured_context = (
97
+ "### VISUAL INSPECTION REPORT (TRACK 3)\n"
98
+ f"{vision_text}\n\n"
99
+ "### TASK\n"
100
+ "Based on the visual inspection report above, identify any violations of structural codes "
101
+ "(e.g., Eurocodes, ASTM, ISO 9001). Provide a label of 'Compliant' or 'Non-Compliant' "
102
+ "followed by a detailed reasoning trace."
103
+ )
104
+ return structured_context
105
+
106
+ if __name__ == "__main__":
107
+ # Test the pipeline
108
+ test_image = "dummy_construction_site.jpg"
109
+ print(f"Testing Multimodal Pipeline with {test_image}")
110
+ try:
111
+ context_for_track2 = analyze_construction_site(test_image)
112
+ print("\n--- Structured Output for Track 2 Model ---\n")
113
+ print(context_for_track2)
114
+ print("\n-------------------------------------------\n")
115
+ except Exception as e:
116
+ print(f"Error: {e}")
backend/amd_hackathon/verify_env.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+ def verify_environment():
6
+ print("========================================")
7
+ print(" MI300X ROCm Environment Verification ")
8
+ print("========================================")
9
+
10
+ # Check PyTorch
11
+ print(f"\n[1] PyTorch Version: {torch.__version__}")
12
+ if not torch.cuda.is_available():
13
+ print("❌ CUDA/HIP is not available. Please check your ROCm installation.")
14
+ sys.exit(1)
15
+
16
+ print("βœ… PyTorch is installed with CUDA/HIP support.")
17
+
18
+ # Check ROCm specific device properties
19
+ device_count = torch.cuda.device_count()
20
+ print(f" Available GPUs: {device_count}")
21
+ for i in range(device_count):
22
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
23
+ # Check VRAM
24
+ vram = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3)
25
+ print(f" VRAM GPU {i}: {vram:.2f} GB")
26
+ if "MI300" in torch.cuda.get_device_name(i):
27
+ print(" βœ… MI300X detected.")
28
+
29
+ # Check DeepSpeed
30
+ print("\n[2] Checking DeepSpeed...")
31
+ try:
32
+ import deepspeed
33
+ print(f"βœ… DeepSpeed Version: {deepspeed.__version__}")
34
+ except ImportError:
35
+ print("❌ DeepSpeed is not installed.")
36
+
37
+ # Check Flash Attention
38
+ print("\n[3] Checking Flash Attention 2 (ROCm)...")
39
+ try:
40
+ import flash_attn
41
+ print(f"βœ… Flash Attention 2 Version: {flash_attn.__version__}")
42
+ except ImportError:
43
+ print("❌ Flash Attention 2 is not installed or not configured for ROCm.")
44
+
45
+ # Check Axolotl
46
+ print("\n[4] Checking Axolotl...")
47
+ try:
48
+ import axolotl
49
+ print("βœ… Axolotl is installed.")
50
+ except ImportError:
51
+ print("❌ Axolotl is not installed.")
52
+
53
+ print("\n========================================")
54
+ print(" Verification Complete ")
55
+ print("========================================")
56
+
57
+ if __name__ == "__main__":
58
+ verify_environment()