Rochic committed
Commit 138677a · verified · Parent: e67b24e

Upload 6 files

Files changed (6):
  1. .gitignore +4 -4
  2. README.md +13 -13
  3. app.py +51 -51
  4. main.py +145 -135
  5. requirements.txt +16 -16
  6. test_local.py +14 -0
.gitignore CHANGED
@@ -1,4 +1,4 @@
+ venv
+ __pycache__/
+ *.pyc
+ .DS_Store
README.md CHANGED
@@ -1,13 +1,13 @@
+ ---
+ title: MolGenAI
+ emoji: 🧬
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: "4.38.1"
+ app_file: app.py
+ pinned: false
+ ---
+
+ # 🧬 MolGenAI
+ A GPT-2-based molecular generator for SMILES structures.
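Note: the front matter pins sdk_version: "4.38.1", while requirements.txt below pins gradio==4.44.0. On a Gradio Space the front-matter version is the one the runtime provisions, so these two pins are worth aligning.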
app.py CHANGED
@@ -1,51 +1,51 @@
+ import gradio as gr
+ import torch
+ import main as core
+
+ # Load the model once at import time: the FastAPI startup hook in main.py
+ # never fires when this file runs under Gradio, so without this call
+ # core.model / core.tokenizer would stay None.
+ core.load_model()
+
+ # ---------- Inference function ----------
+ def run_inference(input_text: str):
+     if not input_text.strip():
+         return "Ingresá una configuración para generar el resultado."
+
+     if core.model is None or core.tokenizer is None:
+         return "El modelo no está cargado correctamente."
+
+     try:
+         # Tokenization and generation
+         inputs = core.tokenizer(input_text, return_tensors="pt").to(core.DEVICE)
+         with torch.no_grad():
+             outputs = core.model.generate(
+                 inputs["input_ids"],
+                 max_length=60,
+                 do_sample=True,
+                 top_p=0.95,
+                 temperature=0.8
+             )
+
+         # Decode token ids back into a SMILES string
+         tokens = core.tokenizer.convert_ids_to_tokens(outputs[0])
+         tokens_string = core.decodificar_tokens(tokens)
+         smiles = core.postprocesar_smiles(tokens_string)
+
+         if not smiles or len(smiles.strip()) == 0:
+             return "No se generó ningún SMILES válido."
+
+         return smiles
+
+     except Exception as e:
+         return f"Error interno durante la generación: {str(e)}"
+
+ # ---------- Gradio interface ----------
+ with gr.Blocks(title="MolGen.AI") as demo:
+     gr.Markdown("## 🧬 MolGen.AI — Generación de moléculas")
+     gr.Markdown("Escribí una configuración y generá una estructura SMILES basada en tu modelo.")
+     inp = gr.Textbox(label="Configuración", placeholder="Ej: CCO[NH2+]...", lines=3)
+     btn = gr.Button("Generar", variant="primary")
+     out = gr.Textbox(label="SMILES generados", lines=6)
+     btn.click(fn=run_inference, inputs=inp, outputs=out)
+
+ if __name__ == "__main__":
+     print("🚀 Iniciando MolGen.AI con Gradio...")
+     import os
+     port = int(os.environ.get("PORT", 7860))
+     demo.launch(server_name="0.0.0.0", server_port=port)
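As a quick smoke test of the callback without opening the UI, something like the following should work (a minimal sketch; it assumes the model download succeeds when app.py is imported):

    # Hypothetical smoke test for the Gradio callback
    from app import run_inference   # importing app also triggers core.load_model()

    print(run_inference("CCO"))     # a generated SMILES string, or an error message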
main.py CHANGED
@@ -1,135 +1,145 @@
+ import os
+ import re
+ import torch
+ from typing import Optional
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ app = FastAPI(title="Chem SMILES Generator", version="1.0.0")
+
+ # Enable CORS for all origins
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],      # allows all origins
+     allow_credentials=True,
+     allow_methods=["*"],      # allows all methods
+     allow_headers=["*"],      # allows all headers
+ )
+
+ # ---------- Config ----------
+ MODEL_NAME = os.getenv("MODEL_NAME", "ncfrey/ChemGPT-4.7M")
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]", "[UNK]", "[BOS]", "[EOS]", "[MASK]"}
+
+ # ---------- Global model (loaded once) ----------
+ tokenizer = None
+ model = None
+
+ @app.on_event("startup")
+ def load_model():
+     global tokenizer, model
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+         model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+         model.to(DEVICE)
+         model.eval()
+     except Exception as e:
+         raise RuntimeError(f"No se pudo cargar el modelo '{MODEL_NAME}': {e}")
+
+ # ---------- Decoding / post-processing helpers ----------
+ def decodificar_tokens(tokens):
+     # Drop special tokens and unwrap simple bracketed tokens ([C] -> C);
+     # structural tokens like [Branch1_1] keep their brackets because "_"
+     # is outside the allowed character set.
+     mol = []
+     for tok in tokens:
+         if tok in SPECIAL_TOKENS:
+             continue
+         if tok.startswith("[") and tok.endswith("]"):
+             contenido = tok[1:-1]
+             if re.match(r'^[A-Za-z0-9@=#+\\/-]+$', contenido):
+                 mol.append(contenido)
+             else:
+                 mol.append(tok)
+         else:
+             mol.append(tok)
+     return "".join(mol)
+
+ def postprocesar_smiles(tokens_string: str) -> str:
+     # Turn surviving [Branch...] / [Ring...] markers into SMILES
+     # parentheses and ring-closure digits.
+     pattern = re.compile(r'\[.*?\]')
+     tokens = pattern.split(tokens_string)
+     matches = pattern.findall(tokens_string)
+
+     result = []
+     branch_stack = []
+     ring_open = {}
+
+     for i in range(len(tokens)):
+         result.append(tokens[i])
+
+         if i < len(matches):
+             tok = matches[i]
+
+             if tok.startswith("[Branch"):
+                 result.append("(")
+                 branch_stack.append(")")
+
+             elif tok.startswith("[Ring"):
+                 nums = re.findall(r'\d+', tok)
+                 if nums:
+                     n = nums[0]
+                     if n not in ring_open:
+                         ring_open[n] = True
+                     else:
+                         del ring_open[n]
+                     result.append(n)
+
+             else:
+                 result.append(tok)
+
+     # Close any branches that were left open
+     while branch_stack:
+         result.append(branch_stack.pop())
+
+     return "".join(result)
+
+ # ---------- Request/response types ----------
+ class GenerateRequest(BaseModel):
+     input_text: str
+     max_length: Optional[int] = 60
+     top_k: Optional[int] = 50
+     top_p: Optional[float] = 0.95
+     temperature: Optional[float] = 1.0
+
+ class GenerateResponse(BaseModel):
+     raw_tokens_string: str
+     smiles_postprocesado: str
+
+ # ---------- Routes ----------
+ @app.get("/health")
+ def health():
+     return {"status": "ok", "device": DEVICE, "model": MODEL_NAME}
+
+ @app.post("/generate", response_model=GenerateResponse)
+ def generate(req: GenerateRequest):
+     if tokenizer is None or model is None:
+         raise HTTPException(status_code=500, detail="Modelo no inicializado")
+
+     try:
+         inputs = tokenizer(req.input_text, return_tensors="pt").to(DEVICE)
+         # Prefer the [EOS] token if it exists, otherwise the tokenizer's default
+         eos_id = tokenizer.convert_tokens_to_ids("[EOS]")
+         if eos_id is None or eos_id == tokenizer.unk_token_id:
+             eos_id = tokenizer.eos_token_id
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 inputs["input_ids"],
+                 max_length=req.max_length,
+                 do_sample=True,
+                 top_k=req.top_k,
+                 top_p=req.top_p,
+                 temperature=req.temperature,
+                 eos_token_id=eos_id,
+             )
+
+         tokens = tokenizer.convert_ids_to_tokens(outputs[0])
+         tokens_string = decodificar_tokens(tokens)
+         smiles = postprocesar_smiles(tokens_string)
+         # Strip any leftover "Ring" text and return the cleaned string
+         smiles_final = smiles.replace("Ring", "")
+
+         return GenerateResponse(
+             raw_tokens_string=tokens_string,
+             smiles_postprocesado=smiles_final
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error generando SMILES: {str(e)}")
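To make the post-processing concrete, here is a small worked example of what decodificar_tokens and postprocesar_smiles do to a ChemGPT-style token sequence (a sketch; the exact tokens depend on the model's vocabulary):

    import main as core

    # Hypothetical SELFIES-style tokens as the model might emit them
    tokens = ["[BOS]", "[C]", "[Branch1_1]", "[O]", "[C]", "[EOS]"]

    s = core.decodificar_tokens(tokens)
    # [BOS]/[EOS] are dropped, [C]/[O] are unwrapped to bare atoms, and
    # [Branch1_1] keeps its brackets because "_" is outside the allowed charset
    print(s)                            # -> "C[Branch1_1]OC"

    print(core.postprocesar_smiles(s))  # -> "C(OC)"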
requirements.txt CHANGED
@@ -1,16 +1,16 @@
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch==2.3.1              # stable, lightweight CPU wheel
+ transformers==4.46.3
+ tokenizers==0.20.1        # precompiled wheel (avoids a Rust build)
+ safetensors==0.4.5
+ huggingface_hub==0.24.6
+
+ fastapi==0.115.2
+ uvicorn[standard]==0.30.6
+ pydantic==2.7.4
+ python-dotenv==1.0.1
+ httpx==0.28.1             # only if you use it
+ pytest==8.4.2
+ pytest-asyncio==1.2.0
+ gradio==4.44.0
+ numpy==1.26.4
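A quick way to confirm that the CPU-only torch wheel actually resolved (a sketch; the expected version strings follow from the pins above):

    import torch, transformers, gradio

    print(torch.__version__)           # expected "2.3.1+cpu" via the extra index URL
    print(torch.cuda.is_available())   # False on the CPU wheel, so DEVICE falls back to "cpu"
    print(transformers.__version__)    # "4.46.3"
    print(gradio.__version__)          # "4.44.0"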
test_local.py ADDED
@@ -0,0 +1,14 @@
+ import requests
+
+ url = "http://127.0.0.1:8000/generate"
+ payload = {
+     "input_text": "CCO",
+     "max_length": 60,
+     "top_k": 50,
+     "top_p": 0.95,
+     "temperature": 1.0
+ }
+
+ response = requests.post(url, json=payload)
+ print(response.status_code)
+ print(response.json())
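This script assumes the FastAPI service is already running locally, e.g. started with uvicorn main:app --host 127.0.0.1 --port 8000. Note that requests itself is not pinned in requirements.txt (only httpx is), so either install it separately or switch the call to httpx, which accepts the same post(url, json=payload) shape.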