| """Chart Pattern Detection API — YOLOv8""" |
| import torch |
| import torch.serialization |
| import gradio as gr |
| import json |
|
|
| |
# Best-effort hardening: allow-list every class reachable as an attribute of
# ultralytics.nn.tasks (including classes that module merely imports) for
# torch's weights_only unpickler. This lets YOLO checkpoints load without
# disabling that safety check.
# Failure is non-fatal and is only reported, never raised. Examples: ultralytics
# cannot be imported, or this torch build lacks add_safe_globals.
# NOTE(review): this file also patches torch.load to default weights_only=False,
# which makes this allow-list redundant for calls that rely on that default.
try:
    import ultralytics.nn.tasks as tasks
    safe_classes = [
        obj
        for obj in (getattr(tasks, name, None) for name in dir(tasks))
        if isinstance(obj, type)
    ]
    torch.serialization.add_safe_globals(safe_classes)
except Exception as exc:  # optional hardening: never fatal, but never silent
    print(f"Skipping torch safe-globals registration: {exc!r}")
|
|
| |
import functools

# SECURITY NOTE(review): weights_only=False makes torch.load fully unpickle the
# file, which can execute code embedded in a checkpoint. This patch applies to
# every torch.load call in the process that does not pass weights_only itself.
# That includes the model.pt downloaded from the Hugging Face Hub in this file.
# Keep it only if that repository is trusted; consider pinning a revision.
_orig = torch.load


@functools.wraps(_orig)
def _force(*a, **kw):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    PyTorch 2.6 changed the default of ``weights_only`` to True, which rejects
    full pickled checkpoints such as Ultralytics ``.pt`` files. Only the
    default is changed here: a caller that passes ``weights_only`` explicitly
    (e.g. ``weights_only=True``) keeps its choice instead of being silently
    downgraded to unsafe loading.

    Args:
        *a: Positional arguments forwarded unchanged to ``torch.load``.
        **kw: Keyword arguments forwarded to ``torch.load``.

    Returns:
        Whatever the original ``torch.load`` returns.
    """
    kw.setdefault("weights_only", False)
    return _orig(*a, **kw)


torch.load = _force
|
|
| from huggingface_hub import hf_hub_download |
| from ultralytics import YOLO |
|
|
# Pretrained stock-chart pattern detector published on the Hugging Face Hub.
_MODEL_REPO = "foduucom/stockmarket-pattern-detection-yolov8"
_MODEL_FILE = "model.pt"

# Download the checkpoint (network I/O at import time) and build the detector
# used by the request handler.
model_path = hf_hub_download(
    repo_id=_MODEL_REPO,
    filename=_MODEL_FILE,
)
model = YOLO(model_path)

# Log the model's class labels (presumably a {class_id: label} mapping) at startup.
print(f"Classes: {model.names}")
|
|
| |
| from PIL import Image |
| import numpy as np |
# Startup sanity check: run the detector on an all-black 640x640 RGB frame at a
# very low confidence threshold (0.01). The result is only printed; a non-zero
# count is a warning sign, not a fatal error.
_blank_pixels = np.zeros((640, 640, 3), dtype=np.uint8)
test_img = Image.fromarray(_blank_pixels)
test_res = model.predict(
    source=test_img,
    conf=0.01,
    verbose=False,
)
print(f"Self-test: {len(test_res[0].boxes)} detections on blank (should be 0)")
|
|
def detect_patterns(image, *, conf=0.20, iou=0.45, imgsz=640):
    """Run the YOLOv8 chart-pattern detector on one image and return JSON.

    Args:
        image: The chart to analyse, as delivered by the Gradio
            ``gr.Image(type="pil")`` input, or None when nothing was submitted.
        conf: Minimum detection confidence passed to ``model.predict``.
        iou: IoU threshold passed to ``model.predict`` (used for NMS).
        imgsz: Inference image size passed to ``model.predict``.

    Returns:
        A JSON string ``{"patterns": [...], "count": N}``. Each pattern is
        ``{"label": str, "confidence": float, "bbox": [x1, y1, x2, y2]}``,
        sorted by descending confidence. Confidence is rounded to 3 decimals
        and bbox coordinates to 1 decimal. Every payload carries ``"count"``,
        including the empty-input one. On any exception the payload is
        ``{"patterns": [], "count": 0, "error": "<message>"}``. The function
        catches ``Exception`` itself, so callers always receive JSON.
    """
    try:
        if image is None:
            return json.dumps({"patterns": [], "count": 0})
        results = model.predict(
            source=image, conf=conf, iou=iou, imgsz=imgsz, verbose=False
        )
        patterns = []
        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue
            names = result.names
            # Convert each column to Python lists once, instead of wrapping
            # every box in its own Boxes object and converting element-wise.
            for cls_id, score, xyxy in zip(
                boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist()
            ):
                patterns.append({
                    "label": names.get(int(cls_id), "unknown"),
                    "confidence": round(float(score), 3),
                    "bbox": [round(float(v), 1) for v in xyxy],
                })
        patterns.sort(key=lambda p: p["confidence"], reverse=True)
        return json.dumps({"patterns": patterns, "count": len(patterns)})
    except Exception as e:
        # Deliberate API boundary: report the failure in-band as JSON.
        return json.dumps({"patterns": [], "count": 0, "error": str(e)})
|
|
# Gradio UI/API: one PIL image in, the JSON string produced by detect_patterns
# out (shown in a plain textbox).
demo = gr.Interface(
    fn=detect_patterns,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="Chart Pattern Detection — YOLOv8",
)

# Start the server. show_error=True asks Gradio to surface handler exceptions in
# the UI.
demo.launch(show_error=True)
|
|