cracker0935 committed on
Commit
eb48929
·
verified ·
1 Parent(s): 34732f8

add Backend_files

Browse files
Files changed (6) hide show
  1. Dockerfile +23 -0
  2. best_alzheimer_model.pth +3 -0
  3. main.py +150 -0
  4. model_arch.py +73 -0
  5. preprocessing.py +87 -0
  6. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Python 3.9 base image
FROM python:3.9

# All application code lives under /app
WORKDIR /app

# Install dependencies in their own layer so source edits don't bust the
# cache; --no-cache-dir keeps the image slim.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Bring in the rest of the application source
COPY . .

# Hugging Face Spaces runs the container as a non-root user (uid 1000);
# make the model cache directory world-writable so that user can download
# into it at runtime.
RUN mkdir -p /app/model_cache && chmod 777 /app/model_cache

# HF Spaces expects the server to listen on port 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
best_alzheimer_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f96c6a0d0eccc515b9e17534436941870915081cbed439c50021c3051dd4fb54
3
+ size 574490709
main.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from pydantic import BaseModel
3
+ from contextlib import asynccontextmanager
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers import AutoTokenizer
7
+ from typing import List
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ import os
10
+ import requests # <--- Added this
11
+
12
+ from model_arch import ResearchHybridModel
13
+ from preprocessing import ChaParser
14
+
15
+ CONFIG = {
16
+ 'model_name': 'microsoft/deberta-base',
17
+ 'max_seq_len': 64,
18
+ 'max_word_len': 40,
19
+ 'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
20
+ 'threshold': 0.20,
21
+ # PASTE YOUR COPIED HUGGING FACE LINK BELOW
22
+ 'model_url': os.getenv('MODEL_URL')
23
+ }
24
+
25
+ ml_components = {}
26
+
27
+ @asynccontextmanager
28
+ async def lifespan(app: FastAPI):
29
+ print("Loading Model and Tokenizer...")
30
+ ml_components['tokenizer'] = AutoTokenizer.from_pretrained(CONFIG['model_name'])
31
+
32
+ # --- MODEL DOWNLOAD LOGIC START ---
33
+ model_path = "best_alzheimer_model.pth"
34
+
35
+ if not os.path.exists(model_path):
36
+ print(f"Model file not found. Downloading from Hugging Face...")
37
+ try:
38
+ response = requests.get(CONFIG['model_url'], stream=True)
39
+ response.raise_for_status()
40
+ with open(model_path, "wb") as f:
41
+ for chunk in response.iter_content(chunk_size=8192):
42
+ f.write(chunk)
43
+ print("Download complete.")
44
+ except Exception as e:
45
+ print(f"Error downloading model: {e}")
46
+ raise RuntimeError("Failed to download model file")
47
+ # --- MODEL DOWNLOAD LOGIC END ---
48
+
49
+ model = ResearchHybridModel(model_name=CONFIG['model_name'])
50
+
51
+ # Load state dict
52
+ state_dict = torch.load(model_path, map_location=CONFIG['device'])
53
+
54
+ if list(state_dict.keys())[0].startswith('module.'):
55
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
56
+
57
+ model.load_state_dict(state_dict)
58
+ model.to(CONFIG['device'])
59
+ model.eval()
60
+ ml_components['model'] = model
61
+ print("Model Loaded Successfully.")
62
+ yield
63
+ ml_components.clear()
64
+
65
# Application instance; model loading/unloading is handled by `lifespan`.
app = FastAPI(lifespan=lifespan)

# Allow the deployed frontend to call this API from the browser.
_ALLOWED_ORIGINS = ["https://adtrack.onrender.com"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
75
+
76
class SentenceAttention(BaseModel):
    """One participant sentence with its attention weight (min-max
    normalized for display by the /predict/cha endpoint)."""
    # Cleaned sentence text shown to the frontend.
    sentence: str
    # Attention weight in [0, 1] after normalization.
    attention_score: float
79
+
80
class PredictionResponse(BaseModel):
    """Response body for the /predict/cha endpoint."""
    # Name of the uploaded .cha file.
    filename: str
    # "DEMENTIA" or "HEALTHY CONTROL".
    prediction: str
    # Model probability of the dementia class (softmax over 2 logits).
    confidence: float
    # True when confidence >= CONFIG threshold.
    is_dementia: bool
    # Per-sentence attention weights for frontend highlighting.
    attention_map: List[SentenceAttention]
86
+
87
@app.post("/predict/cha", response_model=PredictionResponse)
async def predict_cha_file(file: UploadFile = File(...)):
    """Classify an uploaded CHAT (.cha) transcript.

    Returns the predicted label, the dementia probability, and a
    per-sentence attention map (min-max normalized for display).

    Raises HTTP 400 for non-.cha uploads or files with no *PAR lines.
    """
    # UploadFile.filename may be None; also accept upper-case extensions
    # such as ".CHA" (previously rejected).
    if not file.filename or not file.filename.lower().endswith('.cha'):
        raise HTTPException(status_code=400, detail="Only .cha files are supported")

    contents = await file.read()
    lines = contents.splitlines()

    parser = ChaParser()
    sentences, features, _ = parser.parse(lines)

    if not sentences:
        raise HTTPException(status_code=400, detail="No *PAR lines found in file")

    # Keep only the most recent utterances when the transcript exceeds the
    # model's maximum sentence-sequence length.
    if len(sentences) > CONFIG['max_seq_len']:
        sentences = sentences[-CONFIG['max_seq_len']:]
        features = features[-CONFIG['max_seq_len']:]

    tokenizer = ml_components['tokenizer']
    model = ml_components['model']

    encoding = tokenizer(
        sentences,
        padding='max_length',
        truncation=True,
        max_length=CONFIG['max_word_len'],
        return_tensors='pt'
    )

    # Add the batch dimension (a batch of one document).
    ids = encoding['input_ids'].unsqueeze(0).to(CONFIG['device'])
    mask = encoding['attention_mask'].unsqueeze(0).to(CONFIG['device'])
    feats = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(CONFIG['device'])
    lengths = torch.tensor([len(sentences)])

    with torch.no_grad():
        logits, attn_weights_tensor = model(ids, mask, feats, lengths)
        prob = F.softmax(logits, dim=1)[:, 1].item()

    attn_weights = attn_weights_tensor.cpu().numpy().flatten()
    attn_weights = attn_weights[:len(sentences)]

    # Min-max normalize attention for frontend display.
    if len(attn_weights) > 0:
        w_min, w_max = attn_weights.min(), attn_weights.max()
        if w_max - w_min > 0:
            attn_weights = (attn_weights - w_min) / (w_max - w_min)

    prediction_label = "DEMENTIA" if prob >= CONFIG['threshold'] else "HEALTHY CONTROL"

    attention_map = [
        SentenceAttention(sentence=sent, attention_score=float(score))
        for sent, score in zip(sentences, attn_weights)
    ]

    return {
        "filename": file.filename,
        "prediction": prediction_label,
        "confidence": prob,
        "is_dementia": prob >= CONFIG['threshold'],
        "attention_map": attention_map
    }
147
+
148
@app.get("/health")
def health_check():
    """Simple liveness endpoint reporting the active compute device."""
    device_name = str(CONFIG['device'])
    return {"status": "active", "device": device_name}
model_arch.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import AutoModel
5
+
6
+ class GatedFeatureFusion(nn.Module):
7
+ def __init__(self, embed_dim, feature_dim):
8
+ super().__init__()
9
+ self.feat_proj = nn.Linear(feature_dim, embed_dim)
10
+ self.gate = nn.Sequential(
11
+ nn.Linear(embed_dim * 2, embed_dim),
12
+ nn.Sigmoid()
13
+ )
14
+ self.norm = nn.LayerNorm(embed_dim)
15
+
16
+ def forward(self, text_embeds, raw_features):
17
+ feat_embeds = F.relu(self.feat_proj(raw_features))
18
+ combined = torch.cat([text_embeds, feat_embeds], dim=2)
19
+ z = self.gate(combined)
20
+ fused = z * text_embeds + (1 - z) * feat_embeds
21
+ return self.norm(fused)
22
+
23
class ResearchHybridModel(nn.Module):
    """Hierarchical dementia classifier.

    Pipeline: transformer sentence encoder -> gated fusion with linguistic
    features -> BiLSTM over the sentence sequence -> additive attention
    pooling -> 2-way classifier.

    forward() returns (logits, attention_weights).
    """

    def __init__(self, model_name='microsoft/deberta-base', feature_dim=6):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Read the hidden size from the backbone config instead of assuming
        # 768, so alternative checkpoints (e.g. *-large) also work.
        self.bert_hidden = getattr(self.bert.config, 'hidden_size', 768)

        self.fusion = GatedFeatureFusion(self.bert_hidden, feature_dim)

        lstm_hidden = 256
        self.lstm = nn.LSTM(
            input_size=self.bert_hidden,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )

        lstm_out_dim = 2 * lstm_hidden  # bidirectional -> forward + backward
        self.attention = nn.Sequential(
            nn.Linear(lstm_out_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(lstm_out_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 2)
        )

    def forward(self, input_ids, attention_mask, linguistic_features, lengths):
        """Classify a batch of documents.

        input_ids/attention_mask: (batch, seq, word) token tensors.
        linguistic_features: (batch, seq, feature_dim) float tensor.
        lengths: (batch,) number of real (non-padded) sentences per document.
        Returns (logits of shape (batch, 2), attention weights).
        """
        batch_size, seq_len, word_len = input_ids.shape

        # Encode every sentence independently: flatten (B, S, W) -> (B*S, W),
        # take the [CLS] position embedding, then restore (B, S, H).
        flat_input = input_ids.view(-1, word_len)
        flat_mask = attention_mask.view(-1, word_len)
        bert_out = self.bert(flat_input, attention_mask=flat_mask).last_hidden_state
        sent_embeds = bert_out[:, 0, :].view(batch_size, seq_len, -1)

        fused = self.fusion(sent_embeds, linguistic_features)

        # Pack so padded sentence slots do not influence the LSTM state.
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=seq_len)

        # Mask padded positions before softmax so they get ~zero attention.
        attn_scores = self.attention(lstm_out)
        mask = (torch.arange(seq_len, device=input_ids.device)[None, :]
                < lengths.to(input_ids.device)[:, None]).float().unsqueeze(2)
        attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(attn_scores, dim=1)

        context = torch.sum(lstm_out * attn_weights, dim=1)
        # NOTE(review): .squeeze() also drops the batch dim when batch == 1;
        # callers that flatten the weights (as main.py does) are unaffected,
        # but batch-shape-sensitive callers should be aware.
        return self.classifier(context), attn_weights.squeeze()
preprocessing.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import numpy as np
3
+
4
class LiveFeatureExtractor:
    """Extract CHAT-annotation disfluency features from a single *PAR line."""

    def __init__(self):
        # CHAT markers counted as features: fillers "&-um", repetitions
        # "[/]"/"[//]", retracings "[//]", incomplete utterances "+...",
        # error codes "[* ...]", and pauses "(.)", "(..)", ...
        self.patterns = {
            'fillers': re.compile(r'&-([a-z]+)', re.IGNORECASE),
            'repetition': re.compile(r'\[/+\]'),
            'retracing': re.compile(r'\[//\]'),
            'incomplete': re.compile(r'\+[\./]+'),
            'errors': re.compile(r'\[\*.*?\]'),
            'pauses': re.compile(r'\(\.+\)')
        }

    @staticmethod
    def _content_words(raw_text):
        """Lower-cased word list with bracket annotations, filler markers
        and punctuation removed (shared by word count and TTR)."""
        text = re.sub(r'\[.*?\]', '', raw_text)
        text = re.sub(r'&-([a-z]+)', '', text)
        text = re.sub(r'[^\w\s]', '', text)
        return text.lower().split()

    def clean_for_bert(self, raw_text):
        """Return display/model text: speaker tag, timing codes and bracket
        annotations removed; pauses become a literal [PAUSE] token."""
        text = re.sub(r'^\*PAR:\s+', '', raw_text)
        text = re.sub(r'\x15\d+_\d+\x15', '', text)  # CLAN timing bullets
        text = re.sub(r'<|>', '', text)
        text = re.sub(r'\[.*?\]', '', text)
        text = re.sub(r'\(\.+\)', '[PAUSE]', text)
        text = text.replace('_', ' ')
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def get_features(self, raw_text):
        """Return raw marker counts plus 'word_count'."""
        stats = {k: len(p.findall(raw_text)) for k, p in self.patterns.items()}
        stats['word_count'] = len(self._content_words(raw_text))
        return stats

    def get_vector(self, raw_text, global_ttr_override=None):
        """Return the 6-dim feature vector
        [ttr, fillers, repetition, retracing, errors, pauses],
        with counts normalized by word count (clamped to >= 1)."""
        stats = self.get_features(raw_text)
        # Clamp to 1 so empty utterances don't divide by zero.
        n = stats['word_count'] if stats['word_count'] > 0 else 1

        if global_ttr_override is not None:
            ttr = global_ttr_override
        else:
            # Per-utterance type-token ratio; if there are no words the set
            # is empty and ttr is 0.
            words = self._content_words(raw_text)
            ttr = len(set(words)) / n

        return [
            ttr,
            stats['fillers'] / n,
            stats['repetition'] / n,
            stats['retracing'] / n,
            stats['errors'] / n,
            stats['pauses'] / n
        ]
55
+
56
class ChaParser:
    """Parse a CHAT (.cha) transcript into participant sentences + features."""

    def __init__(self):
        self.extractor = LiveFeatureExtractor()

    def parse(self, file_content_lines):
        """Parse transcript lines (str or bytes) from a .cha file.

        Returns (sentences, features, raw_lines) for every *PAR: line.
        The type-token ratio is computed globally over the whole session
        (per-sentence TTR is too noisy) and injected into each feature
        vector via global_ttr_override.
        """
        decoded = [
            line.decode('utf-8') if isinstance(line, bytes) else line
            for line in file_content_lines
        ]
        # Only participant (*PAR:) tiers are analyzed; other speakers and
        # dependent tiers (%mor, %gra, ...) are skipped. Collect once
        # instead of filtering the file twice.
        par_lines = [line for line in decoded if line.startswith('*PAR:')]

        session_words = []
        for line in par_lines:
            clean_line = re.sub(r'[^\w\s]', '', line.replace('*PAR:', ''))
            session_words.extend(clean_line.lower().split())

        total_words = len(session_words)
        global_ttr = len(set(session_words)) / total_words if total_words > 0 else 0.0

        sentences, features, raw_lines = [], [], []
        for line in par_lines:
            sentences.append(self.extractor.clean_for_bert(line))
            features.append(self.extractor.get_vector(line, global_ttr_override=global_ttr))
            raw_lines.append(line.strip())

        return sentences, features, raw_lines
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ python-multipart
4
+ torch
5
+ transformers
6
+ numpy
7
+ regex
8
+ requests
9
+ python-dotenv