Spaces:
Sleeping
Sleeping
Vaishnav14220
commited on
Commit
·
3827772
1
Parent(s):
65371c6
Add Gemini AI integration for fixing reaction data with API key input and fix button
Browse files- app.py +97 -4
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -6,7 +6,10 @@ from rdkit import Chem
|
|
| 6 |
from rdkit.Chem import AllChem, Draw
|
| 7 |
import io
|
| 8 |
import tempfile
|
|
|
|
| 9 |
import os
|
|
|
|
|
|
|
| 10 |
from reportlab.pdfgen import canvas
|
| 11 |
from reportlab.lib.pagesizes import letter
|
| 12 |
from reportlab.lib.styles import getSampleStyleSheet
|
|
@@ -205,6 +208,85 @@ def get_autocomplete_reactions(query):
|
|
| 205 |
matches = process.extract(query, reaction_names, limit=10)
|
| 206 |
return [m[0] for m in matches if m[1] > 60]
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
def get_autocomplete_reactants(query):
|
| 209 |
if not query:
|
| 210 |
return unique_reactants[:10]
|
|
@@ -252,15 +334,27 @@ with gr.Blocks(title="Organic Reactions Search") as demo:
|
|
| 252 |
pdf_btn.click(generate_all_reactions_pdf, outputs=pdf_output)
|
| 253 |
|
| 254 |
with gr.Tab("View All Reactions (Table)"):
|
| 255 |
-
gr.Markdown("Browse all 828 reactions in a tabular format.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
-
# Create HTML table
|
| 258 |
def create_reactions_table():
|
| 259 |
html = """
|
| 260 |
<table style="width:100%; border-collapse: collapse;">
|
| 261 |
<thead>
|
| 262 |
<tr style="background-color: #f2f2f2;">
|
| 263 |
-
<th style="border: 1px solid #ddd; padding: 8px;">Action</th>
|
| 264 |
<th style="border: 1px solid #ddd; padding: 8px;">Reaction Name</th>
|
| 265 |
<th style="border: 1px solid #ddd; padding: 8px;">Reactants</th>
|
| 266 |
<th style="border: 1px solid #ddd; padding: 8px;">Reactants SMILES</th>
|
|
@@ -289,7 +383,6 @@ with gr.Blocks(title="Organic Reactions Search") as demo:
|
|
| 289 |
|
| 290 |
html += f"""
|
| 291 |
<tr>
|
| 292 |
-
<td style="border: 1px solid #ddd; padding: 8px;"><button onclick="alert('Fix with AI functionality for {reaction_name} would be implemented here')">Fix with AI</button></td>
|
| 293 |
<td style="border: 1px solid #ddd; padding: 8px;">{reaction_name}</td>
|
| 294 |
<td style="border: 1px solid #ddd; padding: 8px;">{reactants}</td>
|
| 295 |
<td style="border: 1px solid #ddd; padding: 8px; font-family: monospace; font-size: 12px;">{reactants_smiles}</td>
|
|
|
|
| 6 |
from rdkit.Chem import AllChem, Draw
|
| 7 |
import io
|
| 8 |
import tempfile
|
| 9 |
+
import base64
|
| 10 |
import os
|
| 11 |
+
from google import genai
|
| 12 |
+
from google.genai import types
|
| 13 |
from reportlab.pdfgen import canvas
|
| 14 |
from reportlab.lib.pagesizes import letter
|
| 15 |
from reportlab.lib.styles import getSampleStyleSheet
|
|
|
|
| 208 |
matches = process.extract(query, reaction_names, limit=10)
|
| 209 |
return [m[0] for m in matches if m[1] > 60]
|
| 210 |
|
| 211 |
+
def fix_reaction_with_gemini(reaction_name, api_key):
|
| 212 |
+
if not api_key:
|
| 213 |
+
return "Please provide a Gemini API key."
|
| 214 |
+
|
| 215 |
+
try:
|
| 216 |
+
client = genai.Client(api_key=api_key)
|
| 217 |
+
|
| 218 |
+
prompt = f"""Please provide detailed information about the organic reaction named "{reaction_name}".
|
| 219 |
+
Include the correct reaction name, reactants, reagents, products, byproducts, reaction conditions, mechanism, and description.
|
| 220 |
+
Make sure to provide accurate chemical information."""
|
| 221 |
+
|
| 222 |
+
contents = [
|
| 223 |
+
types.Content(
|
| 224 |
+
role="user",
|
| 225 |
+
parts=[types.Part.from_text(text=prompt)],
|
| 226 |
+
),
|
| 227 |
+
]
|
| 228 |
+
|
| 229 |
+
generate_content_config = types.GenerateContentConfig(
|
| 230 |
+
thinking_config=types.ThinkingConfig(thinking_budget=-1),
|
| 231 |
+
response_mime_type="application/json",
|
| 232 |
+
response_schema=genai.types.Schema(
|
| 233 |
+
type=genai.types.Type.OBJECT,
|
| 234 |
+
required=["reaction name", "reactants", "reagents", "products", "byproducts", "conditions", "mechanism", "description"],
|
| 235 |
+
properties={
|
| 236 |
+
"reaction name": genai.types.Schema(type=genai.types.Type.STRING),
|
| 237 |
+
"reactants": genai.types.Schema(
|
| 238 |
+
type=genai.types.Type.ARRAY,
|
| 239 |
+
items=genai.types.Schema(type=genai.types.Type.STRING),
|
| 240 |
+
),
|
| 241 |
+
"reagents": genai.types.Schema(
|
| 242 |
+
type=genai.types.Type.ARRAY,
|
| 243 |
+
items=genai.types.Schema(type=genai.types.Type.STRING),
|
| 244 |
+
),
|
| 245 |
+
"products": genai.types.Schema(
|
| 246 |
+
type=genai.types.Type.ARRAY,
|
| 247 |
+
items=genai.types.Schema(type=genai.types.Type.STRING),
|
| 248 |
+
),
|
| 249 |
+
"byproducts": genai.types.Schema(
|
| 250 |
+
type=genai.types.Type.ARRAY,
|
| 251 |
+
items=genai.types.Schema(type=genai.types.Type.STRING),
|
| 252 |
+
),
|
| 253 |
+
"conditions": genai.types.Schema(type=genai.types.Type.STRING),
|
| 254 |
+
"mechanism": genai.types.Schema(type=genai.types.Type.STRING),
|
| 255 |
+
"description": genai.types.Schema(type=genai.types.Type.STRING),
|
| 256 |
+
},
|
| 257 |
+
),
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
response_text = ""
|
| 261 |
+
for chunk in client.models.generate_content_stream(
|
| 262 |
+
model="gemini-2.0-flash-exp",
|
| 263 |
+
contents=contents,
|
| 264 |
+
config=generate_content_config,
|
| 265 |
+
):
|
| 266 |
+
response_text += chunk.text
|
| 267 |
+
|
| 268 |
+
# Parse the JSON response
|
| 269 |
+
import json
|
| 270 |
+
gemini_data = json.loads(response_text)
|
| 271 |
+
|
| 272 |
+
# Update the global df (this won't persist, but shows the updated data)
|
| 273 |
+
global df
|
| 274 |
+
mask = df['corrected_name'].str.lower() == reaction_name.lower()
|
| 275 |
+
if mask.any():
|
| 276 |
+
idx = df[mask].index[0]
|
| 277 |
+
df.at[idx, 'corrected_name'] = gemini_data.get('reaction name', reaction_name)
|
| 278 |
+
df.at[idx, 'general_reactants'] = ', '.join(gemini_data.get('reactants', []))
|
| 279 |
+
df.at[idx, 'general_reagents'] = ', '.join(gemini_data.get('reagents', []))
|
| 280 |
+
df.at[idx, 'general_products'] = ', '.join(gemini_data.get('products', []))
|
| 281 |
+
df.at[idx, 'description'] = gemini_data.get('description', df.at[idx, 'description'])
|
| 282 |
+
|
| 283 |
+
return f"Successfully updated reaction '{reaction_name}' with AI-generated data:\n\n**New Name:** {gemini_data.get('reaction name')}\n**Reactants:** {', '.join(gemini_data.get('reactants', []))}\n**Reagents:** {', '.join(gemini_data.get('reagents', []))}\n**Products:** {', '.join(gemini_data.get('products', []))}\n**Description:** {gemini_data.get('description', '')[:200]}..."
|
| 284 |
+
else:
|
| 285 |
+
return f"Reaction '{reaction_name}' not found in database."
|
| 286 |
+
|
| 287 |
+
except Exception as e:
|
| 288 |
+
return f"Error calling Gemini API: {str(e)}"
|
| 289 |
+
|
| 290 |
def get_autocomplete_reactants(query):
|
| 291 |
if not query:
|
| 292 |
return unique_reactants[:10]
|
|
|
|
| 334 |
pdf_btn.click(generate_all_reactions_pdf, outputs=pdf_output)
|
| 335 |
|
| 336 |
with gr.Tab("View All Reactions (Table)"):
|
| 337 |
+
gr.Markdown("Browse all 828 reactions in a tabular format. Use the AI Fix section below to improve reaction data.")
|
| 338 |
+
|
| 339 |
+
# AI Fix section
|
| 340 |
+
with gr.Row():
|
| 341 |
+
api_key_input = gr.Textbox(label="Gemini API Key", type="password", placeholder="Enter your Gemini API key")
|
| 342 |
+
reaction_to_fix = gr.Dropdown(label="Select Reaction to Fix", choices=reaction_names)
|
| 343 |
+
|
| 344 |
+
fix_button = gr.Button("Fix with AI")
|
| 345 |
+
ai_status = gr.Markdown(label="AI Fix Status")
|
| 346 |
+
|
| 347 |
+
fix_button.click(fix_reaction_with_gemini, inputs=[reaction_to_fix, api_key_input], outputs=ai_status)
|
| 348 |
+
|
| 349 |
+
gr.Markdown("---")
|
| 350 |
+
gr.Markdown("**Database Table:**")
|
| 351 |
|
| 352 |
+
# Create HTML table (read-only for browsing)
|
| 353 |
def create_reactions_table():
|
| 354 |
html = """
|
| 355 |
<table style="width:100%; border-collapse: collapse;">
|
| 356 |
<thead>
|
| 357 |
<tr style="background-color: #f2f2f2;">
|
|
|
|
| 358 |
<th style="border: 1px solid #ddd; padding: 8px;">Reaction Name</th>
|
| 359 |
<th style="border: 1px solid #ddd; padding: 8px;">Reactants</th>
|
| 360 |
<th style="border: 1px solid #ddd; padding: 8px;">Reactants SMILES</th>
|
|
|
|
| 383 |
|
| 384 |
html += f"""
|
| 385 |
<tr>
|
|
|
|
| 386 |
<td style="border: 1px solid #ddd; padding: 8px;">{reaction_name}</td>
|
| 387 |
<td style="border: 1px solid #ddd; padding: 8px;">{reactants}</td>
|
| 388 |
<td style="border: 1px solid #ddd; padding: 8px; font-family: monospace; font-size: 12px;">{reactants_smiles}</td>
|
requirements.txt
CHANGED
|
@@ -5,4 +5,5 @@ fuzzywuzzy
|
|
| 5 |
python-levenshtein
|
| 6 |
rdkit
|
| 7 |
reportlab
|
| 8 |
-
svglib
|
|
|
|
|
|
| 5 |
python-levenshtein
|
| 6 |
rdkit
|
| 7 |
reportlab
|
| 8 |
+
svglib
|
| 9 |
+
google-genai
|