# qrit-api / app.py
# Last commit 9240df3 by samdak93: "change dockerfile"
import os
# Set cache directories to use /tmp which is typically writable
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface'
os.environ['HF_HOME'] = '/tmp/huggingface'
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import pipeline
import re
import logging
from datetime import datetime
# Model to load: a Hub repo id or local directory, overridable through the
# MODEL_PATH environment variable.
MODEL_PATH = os.getenv('MODEL_PATH', 'samdak93/qrit-2')

# Shared text-generation pipeline; stays None until load_model() assigns it.
generator = None

# Root logging at INFO level, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The Flask application, with flask_cors applied using its default settings.
app = Flask(__name__)
CORS(app)
def load_model():
    """Load the text-generation pipeline into the module-level ``generator``.

    ``MODEL_PATH`` is tried first. If loading it raises for any reason, the
    failure is logged with its traceback and the stock ``"gpt2"`` model is
    loaded instead. An exception while loading that fallback propagates to
    the caller.
    """
    global generator
    logger.info("Loading model from %s", MODEL_PATH)
    try:
        generator = pipeline("text-generation", model=MODEL_PATH)
    except Exception:
        # Deliberate best-effort: keep the API usable with a generic model
        # instead of failing outright. logger.exception records the full
        # traceback so the root cause of the custom-model failure is visible.
        logger.exception("Error loading custom model %s", MODEL_PATH)
        logger.info("Falling back to default GPT-2")
        generator = pipeline("text-generation", model="gpt2")
        logger.info("Default GPT-2 model loaded")
    else:
        logger.info("Model loaded successfully")
def generate_recipe(goal_prompt, max_length=300, temperature=0.7):
    """Generate recipe text for *goal_prompt* with the text-generation pipeline.

    The model is prompted with ``"### Goal: <goal>\\n### Recipe: "``. The
    returned string is that prompt followed by the generated continuation. If
    the model runs on into another ``"### Goal:"`` marker, the continuation is
    cut there and stripped.

    Args:
        goal_prompt: Free-text goal inserted into the prompt.
        max_length: Forwarded to the pipeline as ``max_length``. In
            transformers this is a total token budget that includes the
            prompt tokens.
        temperature: Sampling temperature. A value ``> 0`` enables sampling
            at that temperature; ``<= 0`` requests greedy decoding.

    Returns:
        The prompt plus the generated recipe text.

    Raises:
        Exceptions from ``load_model()`` or from the pipeline propagate to the
        caller. They are no longer returned disguised as recipe text.
    """
    if generator is None:
        load_model()

    full_prompt = f"### Goal: {goal_prompt}\n### Recipe: "

    generation_kwargs = {
        "max_length": max_length,
        "num_return_sequences": 1,
        # Ask the pipeline for the continuation only; the prompt is re-added below.
        "return_full_text": False,
    }
    # transformers only honours `temperature` when sampling is enabled. Unless
    # the model's generation config turns sampling on, decoding is greedy and
    # the temperature would otherwise be silently ignored.
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        generation_kwargs["do_sample"] = False

    recipe_part = generator(full_prompt, **generation_kwargs)[0]["generated_text"]

    # Keep only the text before any further "### Goal:" example the model emits.
    head, marker, _ = recipe_part.partition("### Goal:")
    if marker:
        recipe_part = head.strip()
    return full_prompt + recipe_part
def _to_number(text):
    """Convert a numeric string to ``int`` when it is integral, else ``float``."""
    value = float(text)
    return int(value) if value.is_integer() else value


def parse_recipe(recipe_text):
    """Extract structured fields from generated recipe text.

    Recognised markers (each optional):
        ``Title: <text>``: the rest of that line becomes ``title``.
        ``Ingredients:``: the following lines, up to ``Instructions:``,
            ``Nutrition:`` or end of text, become ``ingredients``. Leading
            ``-`` and space bullet characters are removed.
        ``Instructions:``: the following lines, up to ``Nutrition:`` or end
            of text, become ``instructions``. Leading step numbers such as
            ``1.`` or ``1)`` are removed.
        ``Nutrition: <text>``: an amount written directly before
            ``calories``, ``sugar``, ``protein`` or ``fat`` is stored under
            that label. Matching is case-insensitive and allows an optional
            ``g`` unit.

    Args:
        recipe_text: The text to parse.

    Returns:
        A dict with these keys:
        ``title``: str, ``""`` if absent.
        ``ingredients`` and ``instructions``: lists of str.
        ``nutrition``: maps label to a number. The value is an int for whole
            amounts and a float for decimal amounts such as ``2.5g fat``.
    """
    recipe_data = {
        "title": "",
        "ingredients": [],
        "instructions": [],
        "nutrition": {},
    }

    title_match = re.search(r"Title:[ \t]*(.*?)(?:\n|$)", recipe_text)
    if title_match:
        recipe_data["title"] = title_match.group(1).strip()

    # Also stop at Nutrition:/end of text, so ingredients are kept even when
    # the generated text has no Instructions: section.
    ingredients_match = re.search(
        r"Ingredients:(.*?)(?:Instructions:|Nutrition:|$)", recipe_text, re.DOTALL
    )
    if ingredients_match:
        recipe_data["ingredients"] = [
            line.strip().lstrip('- ')
            for line in ingredients_match.group(1).strip().split('\n')
            if line.strip()
        ]

    instructions_match = re.search(
        r"Instructions:(.*?)(?:Nutrition:|$)", recipe_text, re.DOTALL
    )
    if instructions_match:
        steps = []
        for line in instructions_match.group(1).strip().split('\n'):
            line = line.strip()
            if line:
                # Drop a leading step number such as "1. " or "1) ".
                steps.append(re.sub(r"^\d+[.)]\s+", "", line).strip())
        recipe_data["instructions"] = steps

    nutrition_match = re.search(r"Nutrition:[ \t]*(.*?)(?:\n|$)", recipe_text)
    if nutrition_match:
        nutrition_text = nutrition_match.group(1).strip()
        for label in ("calories", "sugar", "protein", "fat"):
            # Capture the decimal part too: a bare (\d+) would match only the
            # "5g fat" tail of "2.5g fat".
            amount = re.search(
                rf"(\d+(?:\.\d+)?)\s*(?:g\s*)?{label}", nutrition_text, re.IGNORECASE
            )
            if amount:
                recipe_data["nutrition"][label] = _to_number(amount.group(1))
    return recipe_data
@app.route('/', methods=['GET'])
def home():
    """Describe the API: its name, version and the endpoints it offers."""
    endpoints = {
        "/api/healthcheck": "Check API status",
        "/api/generate": "Generate a recipe",
        "/api/parse": "Parse a recipe text",
    }
    return jsonify(
        name="Recipe Generation API",
        version="1.0.0",
        endpoints=endpoints,
    )
@app.route('/api/healthcheck', methods=['GET'])
def healthcheck():
    """Report that the service is up, stamped with the server's local time."""
    now = datetime.now()
    payload = {
        "status": "ok",
        "message": "Recipe API is running",
        # isoformat(sep=' ') is exactly what str(datetime) produces.
        "timestamp": now.isoformat(sep=' '),
    }
    return jsonify(payload)
# Largest client-supplied ``max_length`` accepted. Generation time grows with
# it, so an unbounded value would let one request tie up the worker. 1024 is
# GPT-2's context size (the fallback model). Override through the
# MAX_LENGTH_LIMIT environment variable for models with a larger context.
MAX_LENGTH_LIMIT = int(os.environ.get('MAX_LENGTH_LIMIT', '1024'))


@app.route('/api/generate', methods=['POST'])
def api_generate_recipe():
    """Generate a recipe for the ``goal`` in the JSON request body.

    Request JSON:
        goal (str, required): non-empty free-text goal.
        max_length (int, optional, default 300): must be between 1 and
            ``MAX_LENGTH_LIMIT``.
        temperature (number, optional, default 0.7): must be ``>= 0``.

    Returns:
        200 with ``{"goal", "recipe_text"}``.
        400 if the body is not a JSON object or a parameter is missing or
            invalid.
        500 if generation raises.
    """
    # silent=True: a missing or malformed JSON body yields None (answered with
    # the 400 below) instead of Flask raising its own 400/415 error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'goal' not in data:
        return jsonify({"error": "Missing 'goal' parameter"}), 400

    goal = data['goal']
    if not isinstance(goal, str) or not goal.strip():
        return jsonify({"error": "'goal' must be a non-empty string"}), 400

    max_length = data.get('max_length', 300)
    # bool is a subclass of int: reject it explicitly so `true` is not read as 1.
    if (isinstance(max_length, bool) or not isinstance(max_length, int)
            or not 1 <= max_length <= MAX_LENGTH_LIMIT):
        return jsonify({
            "error": f"'max_length' must be an integer between 1 and {MAX_LENGTH_LIMIT}"
        }), 400

    temperature = data.get('temperature', 0.7)
    # `not (temperature >= 0)` rather than `temperature < 0` so NaN is rejected too.
    if (isinstance(temperature, bool) or not isinstance(temperature, (int, float))
            or not (temperature >= 0)):
        return jsonify({"error": "'temperature' must be a non-negative number"}), 400

    try:
        recipe_text = generate_recipe(goal, max_length, temperature)
    except Exception as e:
        logger.exception("Error generating recipe")
        return jsonify({"error": str(e)}), 500
    return jsonify({"goal": goal, "recipe_text": recipe_text})
@app.route('/api/parse', methods=['POST'])
def api_parse_recipe():
    """Parse the ``recipe_text`` in the JSON request body with parse_recipe.

    Returns:
        200 with the parsed dict.
        400 if the body is not a JSON object containing a string
            ``recipe_text``.
        500 if parsing raises.
    """
    # silent=True: a missing or malformed JSON body yields None (answered with
    # the 400 below) instead of Flask raising its own 400/415 error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'recipe_text' not in data:
        return jsonify({"error": "Missing 'recipe_text' parameter"}), 400

    recipe_text = data['recipe_text']
    # A non-string would make the regexes raise TypeError. That is a client
    # error, not a server error.
    if not isinstance(recipe_text, str):
        return jsonify({"error": "'recipe_text' must be a string"}), 400

    try:
        parsed = parse_recipe(recipe_text)
    except Exception as e:
        logger.exception("Error parsing recipe")
        return jsonify({"error": str(e)}), 500
    return jsonify(parsed)
if __name__ == '__main__':
    # Make sure the cache directory configured at import time exists before
    # load_model() runs.
    cache_dir = os.environ['TRANSFORMERS_CACHE']
    os.makedirs(cache_dir, exist_ok=True)
    # Load eagerly so the first request does not trigger the lazy load in
    # generate_recipe().
    load_model()
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))