Sameer-Handsome173 committed on
Commit
7641778
·
verified ·
1 Parent(s): 4c32a55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -73
app.py CHANGED
@@ -4,16 +4,17 @@
4
 
5
  from fastapi import FastAPI, UploadFile, File, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
- from pydantic import BaseModel
8
  from langchain.chains import GraphCypherQAChain, LLMChain
9
  from langchain_community.graphs import Neo4jGraph
10
  from langchain_community.llms import HuggingFaceHub
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
- from langchain.prompts import PromptTemplate, ChatPromptTemplate
13
  from langchain.output_parsers import PydanticOutputParser
14
- from typing import List
15
  import os
16
  import json
 
17
  import uvicorn
18
 
19
  # ================================
@@ -47,6 +48,28 @@ llm = None
47
  qa_chain = None
48
  extraction_chain = None
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # ================================
51
  # Pydantic Models for API
52
  # ================================
@@ -65,27 +88,31 @@ class QueryResponse(BaseModel):
65
  cypher_query: str = None
66
 
67
  # ================================
68
- # Prompt Templates
69
  # ================================
70
 
71
- # 1. Entity Extraction Prompt Template (Simplified for Qwen)
72
- ENTITY_EXTRACTION_TEMPLATE = """Extract entities and relationships from the text below.
 
 
 
 
 
73
 
74
  TEXT:
75
  {text}
76
 
77
- Extract:
78
- 1. ENTITIES: People, organizations, products, technologies, concepts
79
- 2. RELATIONSHIPS: How entities connect (CREATED, WORKS_AT, USES, etc.)
80
-
81
- Output ONLY this JSON format (no other text):
82
- {{"entities": [{{"name": "FastAPI", "type": "Technology", "description": "web framework"}}], "relationships": [{{"source": "Person", "target": "FastAPI", "type": "CREATED"}}]}}
83
 
84
- JSON:"""
85
 
86
  entity_extraction_prompt = PromptTemplate(
87
  input_variables=["text"],
88
- template=ENTITY_EXTRACTION_TEMPLATE
 
89
  )
90
 
91
  # 2. Cypher Generation Prompt Template
@@ -203,82 +230,107 @@ async def startup_event():
203
  # ================================
204
 
205
  def extract_entities_relationships(text_chunk):
206
- """Extract entities and relationships using LangChain prompt template"""
207
 
208
  try:
209
- # Use the extraction chain
210
- response = extraction_chain.run(text=text_chunk)
211
-
212
  print(f"\n{'='*60}")
213
- print("RAW LLM RESPONSE:")
214
- print(response)
215
- print('='*60)
216
-
217
- # Clean response
218
- response = response.strip()
219
 
220
- # Remove markdown code blocks if present
221
- if "```json" in response:
222
- response = response.split("```json")[1].split("```")[0]
223
- elif "```" in response:
224
- response = response.split("```")[1].split("```")[0]
225
-
226
- response = response.strip()
227
-
228
- # Find JSON object
229
- if "{" in response and "}" in response:
230
- start = response.find("{")
231
- end = response.rfind("}") + 1
232
- response = response[start:end]
233
 
234
- print(f"CLEANED JSON:")
235
- print(response)
236
  print('='*60)
237
 
238
- data = json.loads(response)
239
-
240
- print(f"PARSED DATA:")
241
- print(f"Entities: {len(data.get('entities', []))}")
242
- print(f"Relationships: {len(data.get('relationships', []))}")
243
-
244
- return data
245
-
246
- except json.JSONDecodeError as e:
247
- print(f"❌ JSON parsing error: {e}")
248
- print(f"Response was: {response[:500]}")
249
-
250
- # Fallback: Try to extract at least some basic entities
251
- return fallback_extraction(text_chunk)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
  except Exception as e:
254
- print(f"❌ Extraction error: {e}")
255
- return {"entities": [], "relationships": []}
 
256
 
257
  def fallback_extraction(text):
258
- """Simple fallback extraction using basic NLP"""
259
  print("⚠️ Using fallback extraction...")
260
 
261
- # Simple entity extraction - find capitalized words
262
- import re
263
- words = text.split()
264
-
265
  entities = []
266
- seen = set()
 
 
 
 
267
 
268
- for i, word in enumerate(words):
269
- # Find capitalized words (potential entities)
270
- if word[0].isupper() and len(word) > 2:
271
- clean_word = re.sub(r'[^\w\s]', '', word)
272
- if clean_word and clean_word not in seen:
273
- entities.append({
274
- "name": clean_word,
275
- "type": "Concept",
276
- "description": f"Extracted from: {' '.join(words[max(0,i-3):min(len(words),i+4)])}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  })
278
- seen.add(clean_word)
279
 
280
- print(f"Fallback extracted {len(entities)} entities")
281
- return {"entities": entities[:20], "relationships": []}
282
 
283
  def add_to_graph(entities, relationships, doc_name):
284
  """Add entities and relationships to Neo4j with proper sanitization"""
 
4
 
5
  from fastapi import FastAPI, UploadFile, File, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel, Field
8
  from langchain.chains import GraphCypherQAChain, LLMChain
9
  from langchain_community.graphs import Neo4jGraph
10
  from langchain_community.llms import HuggingFaceHub
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from langchain.prompts import PromptTemplate
13
  from langchain.output_parsers import PydanticOutputParser
14
+ from typing import List, Optional
15
  import os
16
  import json
17
+ import re
18
  import uvicorn
19
 
20
  # ================================
 
48
  qa_chain = None
49
  extraction_chain = None
50
 
51
+ # ================================
52
+ # Pydantic Models for Extraction
53
+ # ================================
54
+
55
class Entity(BaseModel):
    """Single entity extracted from text"""

    # Required fields are marked with an explicit `...` so the absence of a
    # default is visible at a glance; the descriptions are kept verbatim
    # because they are attached to the fields as schema metadata.
    name: str = Field(..., description="The name of the entity")
    type: str = Field(
        ...,
        description="Type: Person, Organization, Product, Technology, Concept, Location",
    )
    # Optional free-text description; defaults to an empty string when omitted.
    description: Optional[str] = Field(
        "",
        description="Brief description of the entity",
    )
60
+
61
class Relationship(BaseModel):
    """Relationship between two entities"""

    # Both endpoints and the relationship type are required (explicit `...`);
    # the descriptions are kept verbatim as field schema metadata.
    source: str = Field(..., description="Source entity name")
    target: str = Field(..., description="Target entity name")
    type: str = Field(
        ...,
        description="Relationship type in UPPER_SNAKE_CASE (e.g., CREATED, FOUNDED, USES)",
    )
    # Optional supporting text; defaults to an empty string when omitted.
    context: Optional[str] = Field(
        "",
        description="Context of the relationship",
    )
67
+
68
class ExtractionResult(BaseModel):
    """Complete extraction result"""

    # Both lists are required (explicit `...`): a payload must carry the two
    # keys, even if a list is empty.
    entities: List[Entity] = Field(
        ...,
        description="List of extracted entities",
    )
    relationships: List[Relationship] = Field(
        ...,
        description="List of extracted relationships",
    )
72
+
73
  # ================================
74
  # Pydantic Models for API
75
  # ================================
 
88
  cypher_query: str = None
89
 
90
  # ================================
91
+ # Prompt Templates with Pydantic Parser
92
  # ================================
93
 
94
+ # Create parser for structured output
95
+ extraction_parser = PydanticOutputParser(pydantic_object=ExtractionResult)
96
+
97
+ # 1. Entity Extraction Prompt with Pydantic
98
+ ENTITY_EXTRACTION_TEMPLATE = """Extract entities and relationships from the text.
99
+
100
+ {format_instructions}
101
 
102
  TEXT:
103
  {text}
104
 
105
+ Important:
106
+ - Extract people, organizations, products, technologies, concepts
107
+ - For relationships use: CREATED, FOUNDED, USES, BUILT_ON, WORKS_AT, CEO_OF, INTEGRATES_WITH
108
+ - Be specific and accurate
 
 
109
 
110
+ Your response:"""
111
 
112
  entity_extraction_prompt = PromptTemplate(
113
  input_variables=["text"],
114
+ template=ENTITY_EXTRACTION_TEMPLATE,
115
+ partial_variables={"format_instructions": extraction_parser.get_format_instructions()}
116
  )
117
 
118
  # 2. Cypher Generation Prompt Template
 
230
  # ================================
231
 
232
  def extract_entities_relationships(text_chunk):
233
+ """Extract entities and relationships using Pydantic structured output"""
234
 
235
  try:
 
 
 
236
  print(f"\n{'='*60}")
237
+ print(f"Processing chunk: {text_chunk[:100]}...")
 
 
 
 
 
238
 
239
+ # Use the extraction chain
240
+ response = extraction_chain.run(text=text_chunk)
 
 
 
 
 
 
 
 
 
 
 
241
 
242
+ print(f"RAW LLM RESPONSE:")
243
+ print(response[:500])
244
  print('='*60)
245
 
246
+ # Try to parse with Pydantic parser
247
+ try:
248
+ result = extraction_parser.parse(response)
249
+
250
+ entities = [e.dict() for e in result.entities]
251
+ relationships = [r.dict() for r in result.relationships]
252
+
253
+ print(f"✅ PARSED with Pydantic:")
254
+ print(f" Entities: {len(entities)}")
255
+ print(f" Relationships: {len(relationships)}")
256
+
257
+ return {"entities": entities, "relationships": relationships}
258
+
259
+ except Exception as parse_error:
260
+ print(f"⚠️ Pydantic parsing failed: {parse_error}")
261
+ print("Trying manual JSON extraction...")
262
+
263
+ # Fallback: Try manual JSON extraction
264
+ cleaned = response.strip()
265
+
266
+ # Remove markdown
267
+ if "```json" in cleaned:
268
+ cleaned = cleaned.split("```json")[1].split("```")[0]
269
+ elif "```" in cleaned:
270
+ cleaned = cleaned.split("```")[1].split("```")[0]
271
+
272
+ # Find JSON
273
+ if "{" in cleaned and "}" in cleaned:
274
+ start = cleaned.find("{")
275
+ end = cleaned.rfind("}") + 1
276
+ cleaned = cleaned[start:end]
277
+
278
+ data = json.loads(cleaned)
279
+ print(f"✅ Manual JSON parse successful: {len(data.get('entities', []))} entities")
280
+ return data
281
 
282
  except Exception as e:
283
+ print(f"❌ All parsing failed: {e}")
284
+ print("Using fallback extraction...")
285
+ return fallback_extraction(text_chunk)
286
 
287
  def fallback_extraction(text):
288
+ """Simple rule-based fallback extraction"""
289
  print("⚠️ Using fallback extraction...")
290
 
 
 
 
 
291
  entities = []
292
+ relationships = []
293
+ seen_entities = set()
294
+
295
+ # Split into sentences
296
+ sentences = [s.strip() for s in text.split('.') if s.strip()]
297
 
298
+ for sentence in sentences:
299
+ words = sentence.split()
300
+
301
+ # Extract capitalized words/phrases as entities
302
+ current_entity = []
303
+ for word in words:
304
+ clean = re.sub(r'[^\w\s]', '', word)
305
+ if clean and clean[0].isupper() and len(clean) > 2:
306
+ current_entity.append(clean)
307
+ elif current_entity:
308
+ entity_name = ' '.join(current_entity)
309
+ if entity_name not in seen_entities:
310
+ entities.append({
311
+ "name": entity_name,
312
+ "type": "Concept",
313
+ "description": sentence[:100]
314
+ })
315
+ seen_entities.add(entity_name)
316
+ current_entity = []
317
+
318
+ # Check for common relationship patterns
319
+ if ' created ' in sentence.lower() or ' developed ' in sentence.lower():
320
+ # Try to extract creator and creation
321
+ parts = re.split(r' created | developed ', sentence, flags=re.IGNORECASE)
322
+ if len(parts) == 2:
323
+ creator = parts[0].strip().split()[-1]
324
+ creation = parts[1].strip().split()[0]
325
+ relationships.append({
326
+ "source": creator,
327
+ "target": creation,
328
+ "type": "CREATED",
329
+ "context": sentence[:100]
330
  })
 
331
 
332
+ print(f"Fallback extracted: {len(entities)} entities, {len(relationships)} relationships")
333
+ return {"entities": entities[:15], "relationships": relationships[:10]}
334
 
335
  def add_to_graph(entities, relationships, doc_name):
336
  """Add entities and relationships to Neo4j with proper sanitization"""