Fb56816's picture
Clasificador de proteinas Pfam - PyTorch MLP 5513 familias (92.01% accuracy)
7031d50
"""
Entrenamiento mejorado para clasificaci贸n de familias de prote铆nas.
Usa embeddings ESM-2 + features manuales para alcanzar >80% de precisi贸n.
Soporta todas las familias del Pfam con suficiente ejemplos.
"""
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import classification_report, accuracy_score
from sklearn.decomposition import PCA
import joblib
import os
import glob
import json
import argparse
# One-letter codes of the 20 residues the feature extractor counts.
AMINOACIDOS = 'ACDEFGHIKLMNPQRSTVWY'

# Column order shared by every row of the property table below.
_NOMBRES_PROPIEDADES = ('hydro', 'polar', 'charge', 'size', 'flex', 'aromatic')

# One row per residue; the values line up with _NOMBRES_PROPIEDADES.
# NOTE(review): the 'hydro' column looks like the Kyte-Doolittle hydropathy
# scale -- confirm before relying on that.
_TABLA_PROPIEDADES = (
    ('A', (1.8, 0, 0, 0, 1, 0)),
    ('C', (2.5, 0, 0, 0, 0, 0)),
    ('D', (-3.5, 1, -1, 0, 1, 0)),
    ('E', (-3.5, 1, -1, 1, 1, 0)),
    ('F', (2.8, 0, 0, 1, 0, 1)),
    ('G', (-0.4, 0, 0, -1, 1, 0)),
    ('H', (-3.2, 1, 0.5, 1, 0, 1)),
    ('I', (4.5, 0, 0, 1, 0, 0)),
    ('K', (-3.9, 1, 1, 1, 1, 0)),
    ('L', (3.8, 0, 0, 1, 0, 0)),
    ('M', (1.9, 0, 0, 1, 0, 0)),
    ('N', (-3.5, 1, 0, 0, 1, 0)),
    ('P', (-1.6, 0, 0, 0, -1, 0)),
    ('Q', (-3.5, 1, 0, 1, 1, 0)),
    ('R', (-4.5, 1, 1, 1, 1, 0)),
    ('S', (-0.8, 1, 0, -1, 1, 0)),
    ('T', (-0.7, 1, 0, 0, 1, 0)),
    ('V', (4.2, 0, 0, 0, 0, 0)),
    ('W', (-0.9, 0, 0, 1, 0, 1)),
    ('Y', (-1.3, 1, 0, 1, 0, 1)),
)

# residue -> {property name -> value}, expanded from the compact table.
PROPIEDADES_AA = {
    residuo: dict(zip(_NOMBRES_PROPIEDADES, valores))
    for residuo, valores in _TABLA_PROPIEDADES
}

# Named residue groups; each value is the string of member residues.
GRUPOS_AA = dict(
    nonpolar='AILMFVW',
    polar_uncharged='NQSTY',
    positively_charged='KRH',
    negatively_charged='DE',
    aromatic='FWY',
    aliphatic='ILV',
    tiny='AGS',
    special='CP',
)

# First location probed for the dataset (machine-specific absolute path).
RUTA_DATASET = r"C:\Users\Fabi谩n\Downloads\archive\random_split\random_split"

# Directory holding this script; __file__ is undefined in some interactive
# sessions, in which case the current working directory is used instead.
try:
    _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _SCRIPT_DIR = os.getcwd()

# Fallback dataset location next to the script.
RUTA_LOCAL = os.path.join(_SCRIPT_DIR, "data")
def extraer_caracteristicas_manuales(secuencia):
    """Convert one protein sequence into a fixed-length list of 530 numbers.

    Layout of the returned list, in order:
      * 20  -- fraction of each residue of AMINOACIDOS in the sequence.
      * 400 -- one value per ordered residue pair: ``str.count`` of the pair
               (non-overlapping matches) divided by ``max(n - 1, 1)``.
      * 36  -- six slots per property of PROPIEDADES_AA: mean, std and min
               over the residues, then three extra slots. For 'hydro' the
               extras are the max, min and range of the mean over a sliding
               window of 9 residues; for the other properties they are 0.
      * 6   -- length, log1p(length), K+R+H fraction, D+E fraction,
               C fraction, P fraction.
      * 8   -- fraction of residues belonging to each group of GRUPOS_AA.
      * 60  -- residue fractions computed separately on each third of the
               sequence (the last third absorbs the remainder).

    Args:
        secuencia: The sequence as a string. It is upper-cased before
            counting; residues missing from PROPIEDADES_AA contribute 0 to
            every property.

    Returns:
        A list of 530 numbers; all zeros when ``secuencia`` is falsy or NaN.
    """
    if not secuencia or pd.isna(secuencia) or len(secuencia) == 0:
        return [0.0] * 530

    seq = secuencia.upper()
    n = len(seq)

    def fraccion(texto, letras):
        # Share of `texto` made up of any of the residues in `letras`.
        return sum(texto.count(letra) for letra in letras) / len(texto)

    # Global residue composition.
    feats = [fraccion(seq, aa) for aa in AMINOACIDOS]

    # Dipeptide composition; a one-residue sequence divides by 1, not 0.
    divisor_pares = max(n - 1, 1)
    feats += [
        seq.count(primero + segundo) / divisor_pares
        for primero in AMINOACIDOS
        for segundo in AMINOACIDOS
    ]

    # Physico-chemical profile statistics.
    for propiedad in ('hydro', 'polar', 'charge', 'size', 'flex', 'aromatic'):
        perfil = [PROPIEDADES_AA.get(residuo, {}).get(propiedad, 0) for residuo in seq]
        feats += [np.mean(perfil), np.std(perfil), np.min(perfil)]
        if propiedad != 'hydro':
            # Keep six slots per property so the layout stays fixed.
            feats += [0, 0, 0]
            continue
        ancho = 9
        # Sequences shorter than the window yield one (truncated) window.
        medias_ventana = [
            np.mean(perfil[pos:pos + ancho])
            for pos in range(max(1, n - ancho + 1))
        ]
        pico = np.max(medias_ventana)
        valle = np.min(medias_ventana)
        feats += [pico, valle, pico - valle]

    # Length and selected residue-class fractions.
    feats += [
        n,
        np.log1p(n),
        fraccion(seq, 'KRH'),
        fraccion(seq, 'DE'),
        fraccion(seq, 'C'),
        fraccion(seq, 'P'),
    ]

    # Fraction of residues in each named group.
    feats.extend(fraccion(seq, miembros) for miembros in GRUPOS_AA.values())

    # Composition of each third of the sequence.
    tercio = n // 3
    for ini, fin in ((0, tercio), (tercio, 2 * tercio), (2 * tercio, n)):
        tramo = seq[ini:fin]
        if tramo:
            feats.extend(fraccion(tramo, aa) for aa in AMINOACIDOS)
        else:
            # Sequences shorter than 3 leave the first two thirds empty.
            feats.extend([0.0] * 20)

    return feats
def cargar_datos():
    """Find the training shards on disk and load them into one DataFrame.

    The ``train`` folder is looked up under RUTA_DATASET first, then under
    two layouts inside RUTA_LOCAL, and finally via a recursive search of
    RUTA_LOCAL. Every file named ``data-*`` is read with ``pandas.read_csv``;
    files that fail to parse are reported and skipped.

    Returns:
        pandas.DataFrame: All readable shards concatenated with a fresh index.

    Raises:
        SystemExit: With code 1 when no shard is found, none can be read, or
            the data lacks the ``family_accession`` column.
    """
    print("Cargando datos...")
    rutas_busqueda = [
        os.path.join(RUTA_DATASET, "train"),
        os.path.join(RUTA_LOCAL, "random_split", "random_split", "train"),
        os.path.join(RUTA_LOCAL, "train"),
    ]

    archivos_train = []
    for ruta in rutas_busqueda:
        encontrados = glob.glob(os.path.join(ruta, "data-*"))
        if encontrados:
            archivos_train = encontrados
            print(f"Datos encontrados en: {ruta}")
            break

    if not archivos_train:
        # Last resort: any train/ folder below the local data directory.
        archivos_train = glob.glob(
            os.path.join(RUTA_LOCAL, "**", "train", "data-*"), recursive=True
        )

    if not archivos_train:
        print("ERROR: No se encontraron archivos de entrenamiento")
        print(f"Rutas buscadas: {rutas_busqueda}")
        # `exit` is injected by the `site` module and may be missing;
        # raising SystemExit has the same effect everywhere.
        raise SystemExit(1)

    # glob returns files in arbitrary order; sort so the row order (and the
    # seeded sampling that depends on it) is reproducible across machines.
    archivos_train = sorted(archivos_train)
    print(f"Encontrados {len(archivos_train)} archivos")

    dfs = []
    filas_leidas = 0
    for i, archivo in enumerate(archivos_train):
        try:
            df_temp = pd.read_csv(archivo)
        except Exception as e:
            # Deliberate best-effort: one bad shard must not abort the load.
            print(f" Error leyendo {archivo}: {e}")
            continue
        dfs.append(df_temp)
        filas_leidas += len(df_temp)
        if (i + 1) % 20 == 0:
            print(f" Le铆dos {i+1}/{len(archivos_train)} archivos ({filas_leidas:,} filas)")

    if not dfs:
        print("ERROR: No se pudieron leer los archivos")
        raise SystemExit(1)

    df = pd.concat(dfs, ignore_index=True)
    if 'family_accession' not in df.columns:
        # Fail with a clear message instead of a bare KeyError below.
        print("ERROR: Falta la columna 'family_accession' en los datos")
        raise SystemExit(1)

    print(f"Total datos: {len(df):,} filas, {df['family_accession'].nunique():,} familias")
    return df
def seleccionar_familias(df, min_ejemplos=50, max_familias=1000):
    """Choose which protein families take part in training.

    By default the families with at least ``min_ejemplos`` rows are kept, in
    the order produced by ``value_counts`` and truncated to ``max_familias``.
    If a ``config_familias.txt`` file sits next to the script and lists more
    than five families that exist in ``df``, that list replaces the default
    selection (``min_ejemplos`` and ``max_familias`` are then not applied).

    Args:
        df: DataFrame with a ``family_accession`` column.
        min_ejemplos: Minimum number of rows a family needs.
        max_familias: Maximum number of families in the default selection.

    Returns:
        list: The selected family accessions.
    """
    conteo = df['family_accession'].value_counts()
    suficientes = conteo[conteo >= min_ejemplos]
    familias_validas = suficientes.index[:max_familias].tolist()

    config_path = os.path.join(_SCRIPT_DIR, "config_familias.txt")
    if os.path.exists(config_path):
        with open(config_path, "r") as archivo:
            # One accession per line; blank lines are ignored.
            candidatas = [renglon.strip() for renglon in archivo if renglon.strip()]
        conocidas = [fam for fam in candidatas if fam in conteo.index]
        # `conocidas` is a subset of `candidatas`, so this single test also
        # guarantees that the file itself listed more than five families.
        if len(conocidas) > 5:
            familias_validas = conocidas
            print(f"Usando {len(familias_validas)} familias desde config_familias.txt")

    total_muestras = sum(
        conteo[fam] for fam in familias_validas if fam in conteo.index
    )
    print(f"Familias seleccionadas: {len(familias_validas)}")
    print(f"Total muestras: {total_muestras:,}")
    return familias_validas
def _nombres_features_manuales():
    """Return one label per manual feature, in extraction order.

    Mirrors the layout emitted by extraer_caracteristicas_manuales: every
    property contributes six slots (mean, std, min plus three extras), and
    the extras are meaningful only for 'hydro'.
    """
    nombres = [f'comp_{aa}' for aa in AMINOACIDOS]
    nombres += [f'dip_{aa1}{aa2}' for aa1 in AMINOACIDOS for aa2 in AMINOACIDOS]
    for prop in ['hydro', 'polar', 'charge', 'size', 'flex', 'aromatic']:
        nombres += [f'{prop}_mean', f'{prop}_std', f'{prop}_min']
        if prop == 'hydro':
            nombres += ['hydro_max_win', 'hydro_min_win', 'hydro_range_win']
        else:
            # Constant-zero padding slots written by the extractor.
            nombres += [f'{prop}_pad{k}' for k in range(3)]
    nombres += ['longitud', 'log_longitud', 'carga_pos', 'carga_neg', 'freq_C', 'freq_P']
    nombres += [f'grupo_{g}' for g in GRUPOS_AA]
    nombres += [f'tercio{i}_{aa}' for i in range(3) for aa in AMINOACIDOS]
    return nombres


def _obtener_embeddings(secuencias, y, esm_model, cache_dir="modelos/embeddings_cache"):
    """Return ESM-2 embeddings for ``secuencias``, using the on-disk cache.

    A cache is reused only when it holds as many rows as ``y``; otherwise the
    embeddings are recomputed and the cache files are rewritten.
    NOTE(review): only the row count is validated, not the cached labels --
    a cache built from a different sample of equal size would be accepted.
    """
    # Project-local module; imported lazily so --no-embeddings runs without it.
    from embeddings import ProteinEmbedder, has_embeddings_cache, load_embeddings_cache

    X_emb = None
    if has_embeddings_cache(cache_dir):
        print("\nCargando embeddings desde cache...")
        X_emb, _cached_labels, _metadata = load_embeddings_cache(cache_dir)
        if len(X_emb) == len(y):
            print(f"Cache v谩lido: {X_emb.shape}")
        else:
            print(f"Cache desactualizado ({len(X_emb)} vs {len(y)}). Regenerando...")
            X_emb = None

    if X_emb is None:
        print(f"\nGenerando embeddings ESM-2 ({esm_model})...")
        embedder = ProteinEmbedder(model_name=esm_model)
        emb_batch_size = 4 if embedder.device.type == 'cuda' else 2
        X_emb = embedder.embed_batch(secuencias, batch_size=emb_batch_size)
        # The directory may not exist yet on a first run.
        os.makedirs(cache_dir, exist_ok=True)
        np.save(os.path.join(cache_dir, "embeddings.npy"), X_emb)
        np.save(os.path.join(cache_dir, "labels.npy"), y)

    print(f"Embeddings: {X_emb.shape}")
    return X_emb


def main():
    """Command-line entry point: train, evaluate and save the classifier.

    Steps: parse the CLI options, load the data, select and subsample the
    families, extract manual features (optionally joined with ESM-2
    embeddings), split 80/20 with stratification, fit the scalers on the
    training rows, train the chosen model, report accuracy and write the
    model, label encoder, scalers and ``feature_info.json`` under
    ``modelos/``.

    Raises:
        SystemExit: With code 1 when no family satisfies the selection.
    """
    parser = argparse.ArgumentParser(description='Entrenar clasificador de prote铆nas')
    parser.add_argument('--no-embeddings', action='store_true',
                        help='Usar solo features manuales (sin ESM-2)')
    parser.add_argument('--esm-model', type=str, default='esm2_t6_8M_UR50D',
                        choices=['esm2_t6_8M_UR50D', 'esm2_t12_35M_UR50D',
                                 'esm2_t30_150M_UR50D', 'esm2_t33_650M_UR50D'],
                        help='Modelo ESM-2 a usar')
    parser.add_argument('--max-familias', type=int, default=0,
                        help='Max familias (0=todas las disponibles)')
    parser.add_argument('--min-ejemplos', type=int, default=50)
    parser.add_argument('--muestra', type=int, default=0,
                        help='Max muestras total (0=sin limite)')
    parser.add_argument('--max-per-class', type=int, default=200,
                        help='Max ejemplos por familia al submuestrear')
    parser.add_argument('--modelo', type=str, default='auto',
                        choices=['auto', 'xgboost', 'sgd', 'gboost'],
                        help='Modelo: auto|xgboost|sgd|gboost')
    args = parser.parse_args()
    use_embeddings = not args.no_embeddings

    df = cargar_datos()
    if 'sequence' in df.columns:
        df['seq_usable'] = df['sequence']
    else:
        # Aligned sequences use '.' as gap filler; strip it out.
        df['seq_usable'] = df['aligned_sequence'].str.replace('.', '', regex=False)

    max_familias = args.max_familias if args.max_familias > 0 else len(df['family_accession'].unique())
    familias = seleccionar_familias(df, min_ejemplos=args.min_ejemplos, max_familias=max_familias)
    if not familias:
        # Nothing to train on; stop before the later steps fail obscurely.
        print("ERROR: Ninguna familia cumple los criterios de seleccion")
        raise SystemExit(1)
    df_filtrado = df[df['family_accession'].isin(familias)].copy()
    print(f"\nDataset filtrado: {len(df_filtrado):,} secuencias en {len(familias)} familias")

    # Cap the number of rows per family.
    max_per_class = args.max_per_class
    tamanos = df_filtrado['family_accession'].value_counts()
    if (tamanos > max_per_class).any():
        print(f"Submuestreando (max {max_per_class} por familia)")
        dfs_sampled = [
            grp.sample(n=min(len(grp), max_per_class), random_state=42)
            for _, grp in df_filtrado.groupby('family_accession')
        ]
        df_filtrado = pd.concat(dfs_sampled, ignore_index=True)
        print(f"Despu茅s de submuestreo: {len(df_filtrado):,} secuencias")

    # Optional global cap, spread evenly over the families.
    limite_total = args.muestra if args.muestra > 0 else len(df_filtrado)
    if len(df_filtrado) > limite_total:
        print(f"Limitando a {limite_total:,} muestras totales (estratificado)")
        cuota = max(1, limite_total // len(familias))
        # Sample WITHOUT replacement: duplicated rows would end up on both
        # sides of the train/test split and inflate the reported accuracy.
        df_filtrado = pd.concat(
            [
                grp.sample(n=min(len(grp), cuota), random_state=42)
                for _, grp in df_filtrado.groupby('family_accession')
            ],
            ignore_index=True,
        )
        print(f"Despu茅s de l铆mite: {len(df_filtrado):,} secuencias")
    df_filtrado = df_filtrado.reset_index(drop=True)

    print("\nTop 15 familias:")
    print(df_filtrado['family_accession'].value_counts().head(15))

    print("\nExtrayendo features manuales...")
    manual_feats = []
    batch_size = 10000
    total = len(df_filtrado)
    for i in range(0, total, batch_size):
        batch = df_filtrado['seq_usable'].iloc[i:i+batch_size]
        manual_feats.extend(batch.apply(extraer_caracteristicas_manuales).tolist())
        if (i + batch_size) % 50000 == 0 or i + batch_size >= total:
            print(f" {min(i+batch_size, total):,}/{total:,} secuencias")
    X_manual = np.array(manual_feats, dtype=np.float32)
    y = df_filtrado['family_accession'].values
    secuencias = df_filtrado['seq_usable'].tolist()
    print(f"Features manuales: {X_manual.shape}")

    X_emb = None
    if use_embeddings:
        try:
            X_emb = np.asarray(_obtener_embeddings(secuencias, y, args.esm_model))
            if X_emb.shape[0] != len(y):
                raise ValueError(
                    f"embeddings con {X_emb.shape[0]} filas, se esperaban {len(y)}"
                )
        except Exception as e:
            # Deliberate best-effort: any embedding failure (missing module,
            # no GPU memory, bad cache) degrades to manual features only.
            print(f"\nError con embeddings: {e}")
            print("Usando solo features manuales...")
            use_embeddings = False
            X_emb = None

    print("\nCodificando etiquetas...")
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)
    num_clases = len(label_encoder.classes_)
    print(f"Clases: {num_clases}")

    print("\nDividiendo datos (80/20)...")
    # Split row indices first so the scalers can be fitted on training rows
    # only; fitting them on all rows leaks test-set statistics.
    idx_train, idx_test = train_test_split(
        np.arange(len(y_encoded)), test_size=0.2, random_state=42, stratify=y_encoded
    )
    y_train, y_test = y_encoded[idx_train], y_encoded[idx_test]

    scaler_manual = StandardScaler().fit(X_manual[idx_train])
    X_manual_scaled = scaler_manual.transform(X_manual)
    if use_embeddings:
        print("\nCombinando embeddings + features manuales...")
        scaler_emb = StandardScaler().fit(X_emb[idx_train])
        X = np.hstack([scaler_emb.transform(X_emb), X_manual_scaled])
        print(f"Features combinados: {X.shape}")
    else:
        scaler_emb = None
        X = X_manual_scaled
    X_train, X_test = X[idx_train], X[idx_test]
    print(f"Entrenamiento: {len(X_train):,} | Test: {len(X_test):,}")

    print("\nEntrenando modelo...")
    model_type = args.modelo
    if model_type == 'auto':
        if num_clases > 500:
            model_type = 'sgd'
        elif num_clases <= 100:
            model_type = 'gboost'
        else:
            model_type = 'xgboost'
    elif model_type == 'gboost' and num_clases > 100:
        # gboost is only used up to 100 classes; record the model that is
        # actually trained so feature_info.json does not lie about it.
        print(f"AVISO: gboost admite hasta 100 clases ({num_clases}); usando xgboost")
        model_type = 'xgboost'

    if model_type == 'gboost':
        print("Usando GradientBoostingClassifier")
        modelo = GradientBoostingClassifier(
            n_estimators=300, max_depth=6, learning_rate=0.1,
            subsample=0.8, min_samples_split=5, min_samples_leaf=3,
            max_features='sqrt', random_state=42, verbose=1,
        )
    elif model_type == 'sgd':
        print(f"Usando SGDClassifier ({num_clases} clases, ultra eficiente)")
        from sklearn.linear_model import SGDClassifier
        modelo = SGDClassifier(
            loss='log_loss',
            penalty='l2',
            alpha=1e-4,
            max_iter=1000,
            tol=1e-3,
            random_state=42,
            n_jobs=-1,
            verbose=1,
        )
    else:
        print(f"Usando XGBoost ({num_clases} clases)")
        from xgboost import XGBClassifier
        modelo = XGBClassifier(
            n_estimators=200, max_depth=3, learning_rate=0.3,
            subsample=0.7, colsample_bytree=0.3, min_child_weight=20,
            num_class=num_clases, objective='multi:softprob',
            eval_metric='mlogloss', tree_method='hist',
            n_jobs=-1, random_state=42, verbosity=1,
        )
    modelo.fit(X_train, y_train)

    print("\n=== EVALUACI脫N ===")
    y_pred = modelo.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    print(f"\nPrecisi贸n en test: {acc:.4f} ({acc*100:.1f}%)")
    print("\nReporte por familia (resumido):")
    print(classification_report(y_test, y_pred, zero_division=0))

    os.makedirs("modelos", exist_ok=True)
    joblib.dump(modelo, "modelos/clasificador_proteinas.joblib", compress=3)
    joblib.dump(label_encoder, "modelos/label_encoder.joblib", compress=3)
    joblib.dump(scaler_manual, "modelos/scaler_manual.joblib", compress=3)
    if scaler_emb is not None:
        joblib.dump(scaler_emb, "modelos/scaler_emb.joblib", compress=3)

    emb_dim = int(X_emb.shape[1]) if use_embeddings else 0
    feature_info = {
        'num_features': X.shape[1],
        'use_embeddings': use_embeddings,
        'esm_model': args.esm_model if use_embeddings else None,
        'manual_features': X_manual.shape[1],
        'embedding_dim': emb_dim,
        'familias': list(label_encoder.classes_),
        'accuracy': float(acc),
        'num_clases': num_clases,
        'model_type': model_type,
    }
    with open("modelos/feature_info.json", "w") as f:
        json.dump(feature_info, f, indent=2)
    print("\n[OK] Modelo guardado en 'modelos/clasificador_proteinas.joblib'")
    print("[OK] LabelEncoder en 'modelos/label_encoder.joblib'")

    # Only the tree-based models expose feature importances.
    if hasattr(modelo, 'feature_importances_'):
        print("\n=== Importancia de caracter铆sticas (top 15) ===")
        importancias = modelo.feature_importances_
        top_idx = np.argsort(importancias)[::-1][:15]
        if use_embeddings:
            # Embedding columns come first in X, manual features after them.
            for idx in top_idx:
                if idx < emb_dim:
                    print(f" ESM_dim_{idx}: {importancias[idx]:.4f}")
                else:
                    manual_idx = idx - emb_dim
                    print(f" Manual_feat_{manual_idx}: {importancias[idx]:.4f}")
        else:
            nombres = _nombres_features_manuales()
            for idx in top_idx:
                nombre = nombres[idx] if idx < len(nombres) else f'feat_{idx}'
                print(f" {nombre}: {importancias[idx]:.4f}")

    print(f"\n{'='*50}")
    print("RESUMEN FINAL")
    print(f"{'='*50}")
    print(f"Familias: {num_clases}")
    print(f"Features: {X.shape[1]} {'(embeddings + manuales)' if use_embeddings else '(solo manuales)'}")
    print(f"Precisi贸n: {acc*100:.1f}%")
    if acc >= 0.80:
        print("OBJETIVO ALCANZADO: >= 80% de precisi贸n")
    else:
        print(f"Faltan {(0.80 - acc)*100:.1f}% para el objetivo de 80%")
        if not use_embeddings:
            print("Prueba con embeddings: python train.py (sin --no-embeddings)")
        if args.esm_model == 'esm2_t6_8M_UR50D':
            print("Prueba modelo ESM m谩s grande: python train.py --esm-model esm2_t12_35M_UR50D")


if __name__ == "__main__":
    main()