File size: 2,160 Bytes
c565b1a
0485e2a
 
b824915
 
0485e2a
b824915
0485e2a
b824915
 
 
 
0485e2a
b824915
 
 
 
 
 
 
 
 
 
 
 
 
5511a7a
b824915
 
 
 
 
 
2a6846e
 
c565b1a
 
b824915
 
c565b1a
 
b824915
c565b1a
 
 
e3ced66
 
 
c565b1a
 
c2a5a49
c565b1a
e3ced66
2a6846e
b824915
 
c565b1a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Chart Pattern Detection API — YOLOv8"""
import torch
import torch.serialization
import gradio as gr
import json

# Allowlist every class reachable as an attribute of ultralytics.nn.tasks so
# that torch's restricted (weights_only) unpickler accepts them.
try:
    import ultralytics.nn.tasks as tasks
    safe_classes = [
        obj
        for obj in (getattr(tasks, name, None) for name in dir(tasks))
        if isinstance(obj, type)
    ]
    torch.serialization.add_safe_globals(safe_classes)
except Exception as exc:
    # Best effort: a failed import or an unavailable add_safe_globals must not
    # stop startup, but report it instead of swallowing it silently.
    # (Exception, not a bare except, so Ctrl-C / SystemExit still propagate.)
    print(f"Warning: could not allowlist ultralytics classes: {exc!r}")

# Also force weights_only=False globally.
# SECURITY NOTE(review): with weights_only=False, torch.load performs a full
# pickle load, which can execute arbitrary code embedded in a checkpoint. This
# patch affects every torch.load call in the process, so only load checkpoint
# files from sources you trust.
import functools

_orig = torch.load


@functools.wraps(_orig)
def _force(*a, **kw):
    """Call the original torch.load with ``weights_only`` overridden to False.

    Any ``weights_only`` value supplied by the caller is discarded; all other
    positional and keyword arguments are passed through unchanged.
    """
    kw["weights_only"] = False
    return _orig(*a, **kw)


torch.load = _force

from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Download the weights file from the Hugging Face Hub repo, wrap it in a YOLO
# model, and print the model's class-name mapping at startup.
model_path = hf_hub_download(
    repo_id="foduucom/stockmarket-pattern-detection-yolov8",
    filename="model.pt",
)
model = YOLO(model_path)
print(f"Classes: {model.names}")

# Quick self-test with a blank image
from PIL import Image
import numpy as np
# All-zero (black) 640x640 RGB frame, run through the model at a very low
# confidence threshold; the printed count is expected to be 0.
_blank_pixels = np.zeros((640, 640, 3), dtype=np.uint8)
test_img = Image.fromarray(_blank_pixels)
test_res = model.predict(
    source=test_img,
    conf=0.01,
    verbose=False,
)
print(f"Self-test: {len(test_res[0].boxes)} detections on blank (should be 0)")

def detect_patterns(image):
    """Run the detector on one image and return the detections as a JSON string.

    Args:
        image: The image to analyse, passed straight to ``model.predict`` as
            its ``source``. ``None`` (nothing submitted) yields an empty result.

    Returns:
        A JSON string holding an object with:
          * ``"patterns"``: list of ``{"label", "confidence", "bbox"}`` entries
            sorted by descending confidence. ``confidence`` is rounded to 3
            decimals; ``bbox`` holds the box's ``xyxy`` values rounded to 1.
          * ``"count"``: number of entries in ``"patterns"`` (always present).
          * ``"error"``: the exception message, only when prediction failed.

        Exceptions are reported through ``"error"`` rather than raised.
    """
    if image is None:
        return json.dumps({"patterns": [], "count": 0})
    try:
        results = model.predict(
            source=image, conf=0.20, iou=0.45, imgsz=640, verbose=False
        )
        patterns = []
        for r in results:
            boxes = r.boxes
            if boxes is None or len(boxes) == 0:
                continue
            # Convert each attribute to plain Python lists once per result
            # instead of indexing tensors box by box.
            rows = zip(boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist())
            for cls_id, conf, xyxy in rows:
                patterns.append({
                    "label": r.names.get(int(cls_id), "unknown"),
                    "confidence": round(float(conf), 3),
                    "bbox": [round(float(x), 1) for x in xyxy],
                })
        patterns.sort(key=lambda p: p["confidence"], reverse=True)
        return json.dumps({"patterns": patterns, "count": len(patterns)})
    except Exception as e:
        # Request boundary: hand the failure back to the client as data, with
        # the same keys as a successful response.
        return json.dumps({"patterns": [], "count": 0, "error": str(e)})

# Web UI: one PIL image in, the JSON string from detect_patterns out in a
# text box. Launched immediately, with errors shown in the UI.
demo = gr.Interface(
    fn=detect_patterns,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="Chart Pattern Detection — YOLOv8",
)
demo.launch(show_error=True)