# cnn-rag / app.py
# aryan14072001's picture — Update app.py (commit 9a7671d, verified)
"""
Article Writer RAG - HuggingFace Spaces App
Combines RAG enrichment + Your fine-tuned model
"""
import gradio as gr
import os
from groq import Groq
from duckduckgo_search import DDGS
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# ============================================================================
# CONFIGURATION
# ============================================================================
# Groq API key, injected via the Space's secrets; empty string disables RAG
# enrichment (see enrich_notes below).
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
# Hub repo id of the fine-tuned article-writer model loaded at startup.
MODEL_NAME = "aryan14072001/article-writer-rag"
# ============================================================================
# LOAD MODEL
# ============================================================================
print("πŸ”„ Loading model...")
try:
    # fp16 + device_map="auto" lets accelerate place weights on GPU when one
    # is available and fall back to CPU otherwise; low_cpu_mem_usage streams
    # weights in to avoid holding a second full copy in RAM while loading.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Some causal-LM tokenizers ship without a pad token; reuse EOS so that
    # generation with padding does not error out.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("βœ… Model loaded!")
    MODEL_LOADED = True
except Exception as e:
    # Keep the app importable even when the weights cannot be fetched
    # (no network, insufficient RAM, bad repo id); generate_article then
    # returns a demo placeholder instead of crashing the Space.
    print(f"⚠️ Model loading failed: {e}")
    print("Running in demo mode...")
    MODEL_LOADED = False
    model = None
    tokenizer = None
# ============================================================================
# RAG ENRICHMENT (Optional)
# ============================================================================
def enrich_notes(rough_notes):
    """Enrich rough notes with verified background facts via RAG.

    Pipeline: (1) ask a Groq-hosted LLM to extract key entities from the
    notes, (2) search DuckDuckGo for background on each entity, (3) ask the
    LLM to compile the search snippets into a bullet list of verified facts.

    Args:
        rough_notes: Free-form reporter notes (string).

    Returns:
        A string of verified background facts, or a human-readable status
        message when enrichment is unavailable or fails. Never raises.
    """
    if not GROQ_API_KEY:
        return "No enrichment (GROQ_API_KEY not set)"
    try:
        groq = Groq(api_key=GROQ_API_KEY)
        search = DDGS()
        # Step 1: Identify entities worth researching.
        # FIX: llama-3.1-70b-versatile was decommissioned by Groq;
        # llama-3.3-70b-versatile is the documented replacement.
        response = groq.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{
                "role": "user",
                "content": f"Extract 3-5 key entities from:\n\n{rough_notes}\n\nList as: 1. Entity1 2. Entity2..."
            }],
            temperature=0.1,
        )
        entities_text = response.choices[0].message.content
        # Step 2: Parse the numbered list and search each entity (capped at
        # 3 searches to keep latency bounded on the free tier).
        import re
        entities = [e.strip() for e in re.findall(r'\d+\.\s*(.+)', entities_text)[:3]]
        search_results = []
        for entity in entities:
            if not entity:
                continue
            try:
                # FIX: DDGS.text may return a generator and may be empty;
                # materialize and guard instead of blindly indexing [0]
                # (the old bare `except:` silently hid that IndexError).
                hits = list(search.text(entity, max_results=2))
                if hits:
                    search_results.append(
                        f"ENTITY: {entity}\nINFO: {hits[0]['body'][:200]}"
                    )
            except Exception:
                # Best-effort: a failed search for one entity should not
                # abort enrichment for the others.
                continue
        if not search_results:
            return "No background info found"
        # Step 3: Compile the raw snippets into verified facts.
        combined = "\n\n".join(search_results)
        response = groq.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{
                "role": "user",
                "content": f"Extract verified facts:\n\nNOTES: {rough_notes}\n\nSEARCH: {combined}\n\nFormat: VERIFIED BACKGROUND:\nβ€’ fact1\nβ€’ fact2"
            }],
            temperature=0.1,
        )
        return response.choices[0].message.content
    except Exception as e:
        # Enrichment is optional: report the failure as text so the caller
        # can still generate an article without background facts.
        return f"Enrichment failed: {str(e)}"
# ============================================================================
# ARTICLE GENERATION
# ============================================================================
def generate_article(rough_notes, use_rag, temperature, max_tokens):
    """Generate a full article from rough notes with the fine-tuned model.

    Args:
        rough_notes: Reporter's bullet-point notes (string).
        use_rag: If True, enrich the notes with web-searched background
            facts via enrich_notes before generating.
        temperature: Sampling temperature for generation.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        Tuple of (article_text, enriched_context), both strings.
    """
    # Graceful degradation when model weights failed to load at startup.
    if not MODEL_LOADED:
        return "⚠️ Model not loaded. This is a demo placeholder.", "N/A"
    # Step 1: Optionally enrich the notes (slow: LLM + web search).
    enriched = enrich_notes(rough_notes) if use_rag else "RAG disabled"
    # Step 2: Build the chat prompt in the model's template format.
    messages = [
        {
            "role": "system",
            "content": "You are a professional journalist. Expand rough notes into articles. Use verified background if provided."
        },
        {
            "role": "user",
            "content": f"ROUGH NOTES:\n{rough_notes}\n\nVERIFIED BACKGROUND:\n{enriched}\n\nWrite article:"
        }
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=0.85,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Strip everything up to and including the final "assistant" role marker.
    # BUG FIX: the original tested `"assistant" in result.lower()` but then
    # split on the case-sensitive literal, so a capitalised "Assistant"
    # marker was never stripped and the prompt leaked into the article.
    # Search case-insensitively instead.
    marker = "assistant"
    idx = result.lower().rfind(marker)
    if idx != -1:
        article = result[idx + len(marker):].strip()
    else:
        # No marker found: fall back to slicing off the echoed prompt.
        article = result[len(prompt):].strip()
    return article, enriched
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Layout: notes + generation controls in the left column, the generated
# article and the RAG context that produced it in the right column.
with gr.Blocks(title="Article Writer RAG", theme=gr.themes.Soft()) as demo:
    # Header / feature overview.
    gr.Markdown(
        """
        # πŸ“° AI Article Writer with RAG
        Transform rough notes into professional articles.
        **Features:**
        - πŸ”¬ RAG enrichment (Groq + web search)
        - ✍️ Fine-tuned article generation
        - 🌐 API endpoint available
        """
    )
    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column():
            rough_notes = gr.Textbox(
                label="πŸ“ Rough Notes",
                placeholder="β€’ Fact 1\nβ€’ Fact 2\nβ€’ Fact 3",
                lines=10
            )
            with gr.Row():
                # RAG defaults to off: web search + Groq calls add latency
                # and require a GROQ_API_KEY secret (free tier friendly).
                use_rag = gr.Checkbox(label="πŸ”¬ Enable RAG", value=False)
                temperature = gr.Slider(0.1, 1.0, 0.3, label="🌑️ Temperature")
                max_tokens = gr.Slider(128, 512, 400, label="πŸ“ Max Tokens")
            generate_btn = gr.Button("πŸš€ Generate Article", variant="primary")
        # Right column: generated article plus the enriched context used.
        with gr.Column():
            article_output = gr.Textbox(label="πŸ“„ Article", lines=15)
            enriched_output = gr.Textbox(label="πŸ” Enriched Context", lines=5)
    # Clickable example inputs (RAG off, conservative temperature).
    gr.Examples(
        examples=[
            ["β€’ Scientists testified\nβ€’ Dr. Herberman testimony\nβ€’ Study found 2x cancer risk", False, 0.3, 400],
            ["β€’ Tech company AI launch\nβ€’ CEO says revolutionary\nβ€’ $999 price", False, 0.3, 300],
        ],
        inputs=[rough_notes, use_rag, temperature, max_tokens],
    )
    # Footer: programmatic usage notes for the Space's HTTP API.
    gr.Markdown(
        """
        ---
        ### πŸ”Œ API Usage
        ```python
        import requests
        response = requests.post(
            "https://your-space-url.hf.space/api/predict",
            json={
                "data": [
                    "β€’ Your rough notes here", # rough_notes
                    False, # use_rag
                    0.3, # temperature
                    400 # max_tokens
                ]
            }
        )
        article = response.json()["data"][0]
        ```
        **Note:** RAG requires GROQ_API_KEY in Space secrets. CPU inference may be slow (20-60s). Upgrade to GPU for 50x speed boost!
        """
    )
    # Wire the button to the generation pipeline; outputs map positionally
    # to generate_article's (article, enriched) return tuple.
    generate_btn.click(
        fn=generate_article,
        inputs=[rough_notes, use_rag, temperature, max_tokens],
        outputs=[article_output, enriched_output]
    )
# ============================================================================
# LAUNCH
# ============================================================================
if __name__ == "__main__":
    # Start the Gradio server (HF Spaces invokes this entry point directly).
    demo.launch()